#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
BERT token counting and visualization.

Computes the BERT token-count distribution of every sentence1 and of the
last sentence2 in a JSON file of sentence pairs.
"""

import json
import warnings
from collections import Counter

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from transformers import AutoTokenizer

# Suppress transformers warnings
warnings.filterwarnings("ignore")

# Use a non-interactive matplotlib backend to avoid display issues
plt.switch_backend('Agg')


def load_sentence_pairs(file_path):
    """Load the sentence-pair data."""
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        print(f"Loaded {len(data)} sentence pairs")
        return data
    except FileNotFoundError:
        print(f"Error: file not found: {file_path}")
        return None
    except json.JSONDecodeError:
        print("Error: malformed JSON file")
        return None
    except Exception as e:
        print(f"Error while loading file: {e}")
        return None


def initialize_tokenizer(model_name="bert-base-chinese"):
    """Initialize the BERT tokenizer."""
    try:
        print(f"Initializing {model_name} tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        print("Tokenizer initialized successfully")
        return tokenizer
    except Exception as e:
        print(f"Failed to initialize tokenizer: {e}")
        print("Trying fallback tokenizer...")
        try:
            tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")
            print("Using multilingual BERT tokenizer instead")
            return tokenizer
        except Exception as e2:
            print(f"Fallback tokenizer also failed: {e2}")
            return None


def count_bert_tokens(text, tokenizer):
    """Count the BERT tokens of a text (excluding special tokens)."""
    if not text or text.strip() == "":
        return 0
    try:
        # Encode the text without adding special tokens ([CLS]/[SEP])
        tokens = tokenizer.encode(text, add_special_tokens=False)
        return len(tokens)
    except Exception as e:
        print(f"Error while counting tokens: {e}")
        return 0


def get_token_range_label(token_count):
    """Return the 100-token bucket label for a given token count."""
    range_start = (token_count // 100) * 100
    range_end = range_start + 99
    return f"{range_start}-{range_end}"


def analyze_token_distribution(sentence_pairs, tokenizer):
    """Analyze the token distribution."""
    print("\nAnalyzing token distribution...")

    # Collect the token count and source_id of every sentence1
    sentence1_tokens = []
    # Detailed records: (token_count, source_id, sentence_type, sentence_text)
    token_details = []

    for pair in sentence_pairs:
        sentence1 = pair.get('sentence1', '')
        source_id = pair.get('source_id', 'unknown')
        token_count = count_bert_tokens(sentence1, tokenizer)
        sentence1_tokens.append(token_count)
        token_details.append((token_count, source_id, 'sentence1', sentence1))

    # Also count the sentence2 of the last pair
    last_sentence2_tokens = 0
    if sentence_pairs:
        last_pair = sentence_pairs[-1]
        last_sentence2 = last_pair.get('sentence2', '')
        last_source_id = last_pair.get('source_id', 'unknown')
        last_sentence2_tokens = count_bert_tokens(last_sentence2, tokenizer)
        if last_sentence2_tokens > 0:
            token_details.append((last_sentence2_tokens, last_source_id, 'sentence2', last_sentence2))

    print(f"Processed {len(sentence1_tokens)} sentence1 entries")
    print(f"Token count of the last sentence2: {last_sentence2_tokens}")

    return sentence1_tokens, last_sentence2_tokens, token_details


def create_token_distribution_chart(sentence1_tokens, last_sentence2_tokens):
    """Create the token-distribution bar chart."""
    print("\nGenerating token distribution chart...")

    # Merge all token counts to be plotted
    all_tokens = (sentence1_tokens + [last_sentence2_tokens]
                  if last_sentence2_tokens > 0 else sentence1_tokens)

    # Determine the bucket range from the maximum token count
    max_tokens = max(all_tokens) if all_tokens else 0
    max_range = ((max_tokens // 100) + 1) * 100

    # Build 100-token buckets
    ranges = []
    range_labels = []
    for i in range(0, max_range, 100):
        ranges.append((i, i + 99))
        range_labels.append(f"{i}-{i + 99}")

    # Count the sentences in each bucket
    range_counts = [0] * len(ranges)
    for token_count in all_tokens:
        range_index = token_count // 100
        if range_index < len(range_counts):
            range_counts[range_index] += 1

    # Create the figure
    plt.figure(figsize=(12, 8))

    # Bar chart
    bars = plt.bar(range_labels, range_counts, color='skyblue',
                   edgecolor='navy', alpha=0.7)

    # Chart attributes
    plt.title('BERT Token Count Distribution', fontsize=16, fontweight='bold')
    plt.xlabel('Token Count Range', fontsize=12)
    plt.ylabel('Number of Sentences', fontsize=12)
    plt.xticks(rotation=45, ha='right')
    plt.grid(axis='y', alpha=0.3)

    # Value labels on the bars
    for bar, count in zip(bars, range_counts):
        if count > 0:  # only label non-empty bars
            plt.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.5,
                     str(count), ha='center', va='bottom', fontsize=10)

    # Adjust layout
    plt.tight_layout()

    # Summary statistics
    total_sentences = len(all_tokens)
    avg_tokens = np.mean(all_tokens) if all_tokens else 0
    median_tokens = np.median(all_tokens) if all_tokens else 0

    # Statistics text box on the chart
    stats_text = f'Total Sentences: {total_sentences}\n'
    stats_text += f'Average Tokens: {avg_tokens:.1f}\n'
    stats_text += f'Median Tokens: {median_tokens:.1f}\n'
    stats_text += f'Max Tokens: {max_tokens}'

    plt.text(0.02, 0.98, stats_text, transform=plt.gca().transAxes,
             bbox=dict(boxstyle='round', facecolor='white', alpha=0.8),
             verticalalignment='top', fontsize=10)

    return plt


def find_high_token_sentences(token_details, threshold=300):
    """Find sentences whose token count exceeds the threshold."""
    print(f"\n=== Sentences with more than {threshold} tokens ===")

    high_token_sentences = [(count, source_id, sentence_type, sentence)
                            for count, source_id, sentence_type, sentence in token_details
                            if count > threshold]

    if not high_token_sentences:
        print(f"No sentences with more than {threshold} tokens found")
        return []

    # Sort by token count, descending
    high_token_sentences.sort(key=lambda x: x[0], reverse=True)

    print(f"Found {len(high_token_sentences)} sentences with more than {threshold} tokens:")
    print("-" * 80)

    for i, (token_count, source_id, sentence_type, sentence) in enumerate(high_token_sentences, 1):
        print(f"{i}. Source ID: {source_id}")
        print(f"   Type: {sentence_type}")
        print(f"   Token Count: {token_count}")
        print(f"   Content: {sentence[:100]}{'...' if len(sentence) > 100 else ''}")
        print("-" * 80)

    # Save the details to a CSV file
    df_high_tokens = pd.DataFrame(
        high_token_sentences,
        columns=['token_count', 'source_id', 'sentence_type', 'sentence_text'])
    output_file = f'high_token_sentences_over_{threshold}.csv'
    df_high_tokens.to_csv(output_file, index=False, encoding='utf-8-sig')
    print(f"Details saved to: {output_file}")

    return high_token_sentences


def print_detailed_statistics(sentence1_tokens, last_sentence2_tokens):
    """Print detailed statistics."""
    print("\n=== Detailed statistics ===")

    all_tokens = (sentence1_tokens + [last_sentence2_tokens]
                  if last_sentence2_tokens > 0 else sentence1_tokens)

    if not all_tokens:
        print("No data to summarize")
        return

    print(f"Number of sentence1 entries: {len(sentence1_tokens)}")
    print(f"Last sentence2: {'included' if last_sentence2_tokens > 0 else 'no data'}")
    print(f"Total sentences: {len(all_tokens)}")
    print(f"Mean tokens: {np.mean(all_tokens):.2f}")
    print(f"Median tokens: {np.median(all_tokens):.2f}")
    print(f"Min tokens: {min(all_tokens)}")
    print(f"Max tokens: {max(all_tokens)}")
    print(f"Standard deviation: {np.std(all_tokens):.2f}")

    # Per-bucket distribution
    print("\n=== Bucket distribution ===")
    max_tokens = max(all_tokens)
    max_range = ((max_tokens // 100) + 1) * 100

    for i in range(0, max_range, 100):
        count = sum(1 for x in all_tokens if i <= x < i + 100)
        if count > 0:
            percentage = (count / len(all_tokens)) * 100
            print(f"{i}-{i + 99} tokens: {count} sentences ({percentage:.1f}%)")


def main():
    """Main entry point."""
    # Input file path
    input_file = 'segmentation_results_from_7_retried.json'

    # 1. Load the data
    sentence_pairs = load_sentence_pairs(input_file)
    if sentence_pairs is None:
        return

    # 2. Initialize the tokenizer
    tokenizer = initialize_tokenizer("bert-base-chinese")
    if tokenizer is None:
        print("Could not initialize a tokenizer, exiting")
        return

    # 3. Analyze the token distribution
    sentence1_tokens, last_sentence2_tokens, token_details = analyze_token_distribution(sentence_pairs, tokenizer)
    if not sentence1_tokens:
        print("No valid sentence data found")
        return

    # 4. Find sentences with high token counts
    high_token_sentences = find_high_token_sentences(token_details, threshold=300)

    # 5. Print detailed statistics
    # print_detailed_statistics(sentence1_tokens, last_sentence2_tokens)

    # 6. Create the chart
    plt = create_token_distribution_chart(sentence1_tokens, last_sentence2_tokens)

    # 7. Save and show the chart
    try:
        output_file = 'bert_token_distribution.png'
        plt.savefig(output_file, dpi=300, bbox_inches='tight')
        print(f"\nChart saved to: {output_file}")
        plt.show()  # no-op with the 'Agg' backend; kept for interactive backends
    except Exception as e:
        print(f"Error while saving or showing the chart: {e}")
        # Fall back to saving only, without showing
        try:
            plt.savefig('bert_token_distribution.png', dpi=300, bbox_inches='tight')
            print("Chart saved, but could not be displayed")
        except Exception as e2:
            print(f"Saving the chart failed as well: {e2}")

    print("\nDone!")


if __name__ == "__main__":
    main()
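# Sketch of the input format this script appears to assume, inferred only from
# the keys read above ('sentence1', 'sentence2', 'source_id'); the example
# values and IDs below are hypothetical, not taken from the real data file:
#
# [
#   {"source_id": "doc_001", "sentence1": "示例句子一。", "sentence2": "示例句子二。"},
#   {"source_id": "doc_002", "sentence1": "另一个示例句子。", "sentence2": "结尾示例句。"}
# ]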