#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
BERT Token数量统计与可视化
统计sentence1和最后一个sentence2的token数量分布
"""
import json
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from collections import Counter
from transformers import AutoTokenizer
import warnings

# Silence warnings emitted by transformers
warnings.filterwarnings("ignore")

# Use a non-interactive matplotlib backend to avoid display issues
plt.switch_backend('Agg')
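
# Assumed input shape (inferred from the field accesses further down in this
# script): a JSON list of objects, each carrying at least the keys
# 'sentence1', 'sentence2' and 'source_id', e.g.
#   [{"source_id": "doc_001", "sentence1": "...", "sentence2": "..."}, ...]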


def load_sentence_pairs(file_path):
    """Load sentence-pair data from a JSON file."""
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        print(f"Loaded {len(data)} sentence pairs")
        return data
    except FileNotFoundError:
        print(f"Error: file not found: {file_path}")
        return None
    except json.JSONDecodeError:
        print("Error: invalid JSON format")
        return None
    except Exception as e:
        print(f"Error while loading the file: {e}")
        return None


def initialize_tokenizer(model_name="bert-base-chinese"):
    """Initialize a BERT tokenizer, falling back to multilingual BERT on failure."""
    try:
        print(f"Initializing {model_name} tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        print("Tokenizer initialized successfully")
        return tokenizer
    except Exception as e:
        print(f"Failed to initialize tokenizer: {e}")
        print("Trying fallback tokenizer...")
        try:
            tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")
            print("Using multilingual BERT tokenizer instead")
            return tokenizer
        except Exception as e2:
            print(f"Fallback tokenizer also failed: {e2}")
            return None


def count_bert_tokens(text, tokenizer):
    """Count the BERT tokens in a text (special tokens excluded)."""
    if not text or text.strip() == "":
        return 0
    try:
        # Encode without adding the special tokens ([CLS]/[SEP])
        tokens = tokenizer.encode(text, add_special_tokens=False)
        return len(tokens)
    except Exception as e:
        print(f"Error while counting tokens: {e}")
        return 0
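
# Illustrative sketch (assumption; exact counts depend on the vocabulary):
# bert-base-chinese splits Chinese text roughly one WordPiece token per
# character, so for example
#   count_bert_tokens("今天天气很好", tokenizer)   # -> 6
# and special tokens such as [CLS]/[SEP] are never included in the count.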


def get_token_range_label(token_count):
    """Return the 100-wide bucket label for a given token count."""
    range_start = (token_count // 100) * 100
    range_end = range_start + 99
    return f"{range_start}-{range_end}"
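
# Example: get_token_range_label(157) -> "100-199", i.e. the label of the
# 100-wide bucket that the count falls into.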


def analyze_token_distribution(sentence_pairs, tokenizer):
    """Analyze the token count distribution."""
    print("\nAnalyzing token distribution...")
    # Collect the token count of every sentence1 together with its source_id
    sentence1_tokens = []
    token_details = []  # Detail records: (token_count, source_id, sentence_type, sentence_text)
    for pair in sentence_pairs:
        sentence1 = pair.get('sentence1', '')
        source_id = pair.get('source_id', 'unknown')
        token_count = count_bert_tokens(sentence1, tokenizer)
        sentence1_tokens.append(token_count)
        token_details.append((token_count, source_id, 'sentence1', sentence1))
    # Also count the sentence2 of the last pair
    last_sentence2_tokens = 0
    if sentence_pairs:
        last_pair = sentence_pairs[-1]
        last_sentence2 = last_pair.get('sentence2', '')
        last_source_id = last_pair.get('source_id', 'unknown')
        last_sentence2_tokens = count_bert_tokens(last_sentence2, tokenizer)
        if last_sentence2_tokens > 0:
            token_details.append((last_sentence2_tokens, last_source_id, 'sentence2', last_sentence2))
    print(f"Processed {len(sentence1_tokens)} sentence1 entries")
    print(f"Token count of the last sentence2: {last_sentence2_tokens}")
    return sentence1_tokens, last_sentence2_tokens, token_details
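
# Return shape, with hypothetical counts for two pairs whose ids are
# 'id_1' and 'id_2':
#   sentence1_tokens      = [42, 137]
#   last_sentence2_tokens = 88
#   token_details         = [(42, 'id_1', 'sentence1', '...'),
#                            (137, 'id_2', 'sentence1', '...'),
#                            (88, 'id_2', 'sentence2', '...')]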


def create_token_distribution_chart(sentence1_tokens, last_sentence2_tokens):
    """Create a bar chart of the token count distribution."""
    print("\nGenerating token distribution chart...")
    # Merge all token counts that should be plotted
    all_tokens = (sentence1_tokens + [last_sentence2_tokens]) if last_sentence2_tokens > 0 else sentence1_tokens
    # Use the maximum token count to determine the bucket range
    max_tokens = max(all_tokens) if all_tokens else 0
    max_range = ((max_tokens // 100) + 1) * 100
    # Build 100-wide buckets
    ranges = []
    range_labels = []
    for i in range(0, max_range, 100):
        ranges.append((i, i + 99))
        range_labels.append(f"{i}-{i + 99}")
    # Count how many sentences fall into each bucket
    range_counts = [0] * len(ranges)
    for token_count in all_tokens:
        range_index = token_count // 100
        if range_index < len(range_counts):
            range_counts[range_index] += 1
    # Create the figure
    plt.figure(figsize=(12, 8))
    # Draw the bar chart
    bars = plt.bar(range_labels, range_counts, color='skyblue', edgecolor='navy', alpha=0.7)
    # Chart styling
    plt.title('BERT Token Count Distribution', fontsize=16, fontweight='bold')
    plt.xlabel('Token Count Range', fontsize=12)
    plt.ylabel('Number of Sentences', fontsize=12)
    plt.xticks(rotation=45, ha='right')
    plt.grid(axis='y', alpha=0.3)
    # Annotate the bars with their counts
    for bar, count in zip(bars, range_counts):
        if count > 0:  # Only label bars that contain data
            plt.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.5,
                     str(count), ha='center', va='bottom', fontsize=10)
    # Tidy up the layout
    plt.tight_layout()
    # Summary statistics
    total_sentences = len(all_tokens)
    avg_tokens = np.mean(all_tokens) if all_tokens else 0
    median_tokens = np.median(all_tokens) if all_tokens else 0
    # Add a statistics text box to the chart
    stats_text = f'Total Sentences: {total_sentences}\n'
    stats_text += f'Average Tokens: {avg_tokens:.1f}\n'
    stats_text += f'Median Tokens: {median_tokens:.1f}\n'
    stats_text += f'Max Tokens: {max_tokens}'
    plt.text(0.02, 0.98, stats_text, transform=plt.gca().transAxes,
             bbox=dict(boxstyle='round', facecolor='white', alpha=0.8),
             verticalalignment='top', fontsize=10)
    return plt
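
# Note: the function returns the pyplot module itself (not a Figure object),
# so the caller can invoke savefig()/show() on the returned value, as main()
# does below.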


def find_high_token_sentences(token_details, threshold=300):
    """Find sentences whose token count exceeds the threshold."""
    print(f"\n=== Sentences with more than {threshold} tokens ===")
    high_token_sentences = [(count, source_id, sentence_type, sentence)
                            for count, source_id, sentence_type, sentence in token_details
                            if count > threshold]
    if not high_token_sentences:
        print(f"No sentences with more than {threshold} tokens found")
        return []
    # Sort by token count in descending order
    high_token_sentences.sort(key=lambda x: x[0], reverse=True)
    print(f"Found {len(high_token_sentences)} sentences with more than {threshold} tokens:")
    print("-" * 80)
    for i, (token_count, source_id, sentence_type, sentence) in enumerate(high_token_sentences, 1):
        print(f"{i}. Source ID: {source_id}")
        print(f"   Type: {sentence_type}")
        print(f"   Token Count: {token_count}")
        print(f"   Content: {sentence[:100]}{'...' if len(sentence) > 100 else ''}")
        print("-" * 80)
    # Save the details to a CSV file (pandas is imported at module level)
    df_high_tokens = pd.DataFrame(high_token_sentences,
                                  columns=['token_count', 'source_id', 'sentence_type', 'sentence_text'])
    output_file = f'high_token_sentences_over_{threshold}.csv'
    df_high_tokens.to_csv(output_file, index=False, encoding='utf-8-sig')
    print(f"Details saved to: {output_file}")
    return high_token_sentences
"""打印详细统计信息"""
print("\n=== 详细统计信息 ===")
all_tokens = sentence1_tokens + [last_sentence2_tokens] if last_sentence2_tokens > 0 else sentence1_tokens
if not all_tokens:
print("没有数据可统计")
return
print(f"Sentence1总数: {len(sentence1_tokens)}")
print(f"Last Sentence2: {'已包含' if last_sentence2_tokens > 0 else '无数据'}")
print(f"总句子数: {len(all_tokens)}")
print(f"平均token数: {np.mean(all_tokens):.2f}")
print(f"中位数token数: {np.median(all_tokens):.2f}")
print(f"最小token数: {min(all_tokens)}")
print(f"最大token数: {max(all_tokens)}")
print(f"标准差: {np.std(all_tokens):.2f}")
# 按区间统计
print("\n=== 区间分布 ===")
max_tokens = max(all_tokens)
max_range = ((max_tokens // 100) + 1) * 100
for i in range(0, max_range, 100):
count = sum(1 for x in all_tokens if i <= x < i + 100)
if count > 0:
percentage = (count / len(all_tokens)) * 100
print(f"{i}-{i + 99} tokens: {count} 句子 ({percentage:.1f}%)")


def main():
    """Main entry point."""
    # Input file path
    input_file = 'segmentation_results_from_7_retried.json'
    # 1. Load the data
    sentence_pairs = load_sentence_pairs(input_file)
    if sentence_pairs is None:
        return
    # 2. Initialize the tokenizer
    tokenizer = initialize_tokenizer("bert-base-chinese")
    if tokenizer is None:
        print("Could not initialize tokenizer, exiting")
        return
    # 3. Analyze the token distribution
    sentence1_tokens, last_sentence2_tokens, token_details = analyze_token_distribution(sentence_pairs, tokenizer)
    if not sentence1_tokens:
        print("No valid sentence data found")
        return
    # 4. Find sentences with a high token count
    high_token_sentences = find_high_token_sentences(token_details, threshold=300)
    # 5. Print detailed statistics
    # print_detailed_statistics(sentence1_tokens, last_sentence2_tokens)
    # 6. Create the visualization
    plt = create_token_distribution_chart(sentence1_tokens, last_sentence2_tokens)
    # 7. Save and show the chart
    try:
        output_file = 'bert_token_distribution.png'
        plt.savefig(output_file, dpi=300, bbox_inches='tight')
        print(f"\nChart saved as: {output_file}")
        plt.show()
    except Exception as e:
        print(f"Error while saving or showing the chart: {e}")
        # Fall back to saving only, without showing the chart
        try:
            plt.savefig('bert_token_distribution.png', dpi=300, bbox_inches='tight')
            print("Chart saved, but it could not be displayed")
        except Exception as e2:
            print(f"Saving the chart also failed: {e2}")
    print("\nAnalysis complete!")


if __name__ == "__main__":
    main()
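
# Usage sketch (the script file name is an assumption; adjust as needed):
#   pip install transformers pandas matplotlib numpy
#   python bert_token_count.py
# The script reads 'segmentation_results_from_7_retried.json' from the working
# directory and writes 'bert_token_distribution.png', plus
# 'high_token_sentences_over_300.csv' if any sentence exceeds 300 tokens.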