#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
BERT token count statistics and visualization.

Computes the BERT token-count distribution for every sentence1 and for the
last sentence2 in the sentence-pair dataset.
"""

import json
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from collections import Counter
from transformers import AutoTokenizer
import warnings

# Suppress warnings from transformers
warnings.filterwarnings("ignore")

# Use a non-interactive matplotlib backend to avoid display issues
plt.switch_backend('Agg')


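# Expected input format (inferred from the field accesses below, not from a
# formal schema): a JSON list of objects, each with at least "sentence1",
# "sentence2" and "source_id" keys, e.g.
#   [{"source_id": "...", "sentence1": "...", "sentence2": "..."}, ...]
# Run the script directly with the input JSON in the working directory.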
def load_sentence_pairs(file_path):
    """Load the sentence-pair data."""
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        print(f"Loaded {len(data)} sentence pairs")
        return data
    except FileNotFoundError:
        print(f"Error: file not found: {file_path}")
        return None
    except json.JSONDecodeError:
        print("Error: invalid JSON format")
        return None
    except Exception as e:
        print(f"Error while loading the file: {e}")
        return None


def initialize_tokenizer(model_name="bert-base-chinese"):
    """Initialize the BERT tokenizer."""
    try:
        print(f"Initializing {model_name} tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        print("Tokenizer initialized successfully")
        return tokenizer
    except Exception as e:
        print(f"Failed to initialize tokenizer: {e}")
        print("Trying a fallback tokenizer...")
        try:
            tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")
            print("Using the multilingual BERT tokenizer instead")
            return tokenizer
        except Exception as e2:
            print(f"The fallback tokenizer also failed: {e2}")
            return None


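# Note: "bert-base-chinese" tokenizes Chinese text essentially character by
# character (WordPiece over individual CJK characters), so the counts produced
# below track the character length of Chinese text fairly closely.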
def count_bert_tokens(text, tokenizer):
    """Count the BERT tokens in a text (excluding special tokens)."""
    if not text or text.strip() == "":
        return 0

    try:
        # Encode the text without adding special tokens ([CLS], [SEP])
        tokens = tokenizer.encode(text, add_special_tokens=False)
        return len(tokens)
    except Exception as e:
        print(f"Error while counting tokens: {e}")
        return 0


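# Note on the counts above: special tokens are excluded. If sentence1 and
# sentence2 are later fed to BERT as a pair, [CLS] and [SEP] tokens are added
# on top of these counts, and standard BERT models truncate at 512 positions.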
def get_token_range_label(token_count):
    """Return the 100-token-wide range label for a token count."""
    range_start = (token_count // 100) * 100
    range_end = range_start + 99
    return f"{range_start}-{range_end}"


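# Example: get_token_range_label(137) returns "100-199"; the bins are 100
# tokens wide, matching the integer-division binning used for the chart below.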
def analyze_token_distribution(sentence_pairs, tokenizer):
    """Analyze the token distribution."""
    print("\nAnalyzing token distribution...")

    # Collect the token count of every sentence1 together with its source_id
    sentence1_tokens = []
    token_details = []  # Detailed records: (token_count, source_id, sentence_type, sentence_text)

    for pair in sentence_pairs:
        sentence1 = pair.get('sentence1', '')
        source_id = pair.get('source_id', 'unknown')
        token_count = count_bert_tokens(sentence1, tokenizer)
        sentence1_tokens.append(token_count)
        token_details.append((token_count, source_id, 'sentence1', sentence1))

    # Token count of the sentence2 in the last pair
    last_sentence2_tokens = 0
    if sentence_pairs:
        last_pair = sentence_pairs[-1]
        last_sentence2 = last_pair.get('sentence2', '')
        last_source_id = last_pair.get('source_id', 'unknown')
        last_sentence2_tokens = count_bert_tokens(last_sentence2, tokenizer)
        if last_sentence2_tokens > 0:
            token_details.append((last_sentence2_tokens, last_source_id, 'sentence2', last_sentence2))

    print(f"Processed {len(sentence1_tokens)} sentence1 entries")
    print(f"Token count of the last sentence2: {last_sentence2_tokens}")

    return sentence1_tokens, last_sentence2_tokens, token_details


def create_token_distribution_chart(sentence1_tokens, last_sentence2_tokens):
    """Create a bar chart of the token-count distribution."""
    print("\nGenerating the token distribution chart...")

    # Combine all token counts to be plotted
    all_tokens = (sentence1_tokens + [last_sentence2_tokens]) if last_sentence2_tokens > 0 else sentence1_tokens

    # Determine the largest token count to size the bins
    max_tokens = max(all_tokens) if all_tokens else 0
    max_range = ((max_tokens // 100) + 1) * 100

    # Build 100-token-wide bins
    ranges = []
    range_labels = []
    for i in range(0, max_range, 100):
        ranges.append((i, i + 99))
        range_labels.append(f"{i}-{i + 99}")

    # Count sentences per bin
    range_counts = [0] * len(ranges)

    for token_count in all_tokens:
        range_index = token_count // 100
        if range_index < len(range_counts):
            range_counts[range_index] += 1

    # Create the figure
    plt.figure(figsize=(12, 8))

    # Draw the bar chart
    bars = plt.bar(range_labels, range_counts, color='skyblue', edgecolor='navy', alpha=0.7)

    # Chart styling
    plt.title('BERT Token Count Distribution', fontsize=16, fontweight='bold')
    plt.xlabel('Token Count Range', fontsize=12)
    plt.ylabel('Number of Sentences', fontsize=12)
    plt.xticks(rotation=45, ha='right')
    plt.grid(axis='y', alpha=0.3)

    # Add value labels on top of the bars (only where there is data)
    for bar, count in zip(bars, range_counts):
        if count > 0:
            plt.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.5,
                     str(count), ha='center', va='bottom', fontsize=10)

    # Tidy up the layout
    plt.tight_layout()

    # Summary statistics
    total_sentences = len(all_tokens)
    avg_tokens = np.mean(all_tokens) if all_tokens else 0
    median_tokens = np.median(all_tokens) if all_tokens else 0

    # Add a text box with the summary statistics
    stats_text = f'Total Sentences: {total_sentences}\n'
    stats_text += f'Average Tokens: {avg_tokens:.1f}\n'
    stats_text += f'Median Tokens: {median_tokens:.1f}\n'
    stats_text += f'Max Tokens: {max_tokens}'

    plt.text(0.02, 0.98, stats_text, transform=plt.gca().transAxes,
             bbox=dict(boxstyle='round', facecolor='white', alpha=0.8),
             verticalalignment='top', fontsize=10)

    return plt


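# Design note (assumption about intent): the manual binning in
# create_token_distribution_chart mirrors what
# np.histogram(all_tokens, bins=range(0, max_range + 1, 100)) would produce;
# doing it by hand keeps the x-axis range labels explicit.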
def find_high_token_sentences(token_details, threshold=300):
    """Find sentences whose token count exceeds the threshold."""
    print(f"\n=== Sentences with more than {threshold} tokens ===")

    high_token_sentences = [(count, source_id, sentence_type, sentence)
                            for count, source_id, sentence_type, sentence in token_details
                            if count > threshold]

    if not high_token_sentences:
        print(f"No sentences with more than {threshold} tokens were found")
        return []

    # Sort by token count, descending
    high_token_sentences.sort(key=lambda x: x[0], reverse=True)

    print(f"Found {len(high_token_sentences)} sentences with more than {threshold} tokens:")
    print("-" * 80)

    for i, (token_count, source_id, sentence_type, sentence) in enumerate(high_token_sentences, 1):
        print(f"{i}. Source ID: {source_id}")
        print(f"   Type: {sentence_type}")
        print(f"   Token Count: {token_count}")
        print(f"   Content: {sentence[:100]}{'...' if len(sentence) > 100 else ''}")
        print("-" * 80)

    # Save the details to a CSV file (pandas is already imported at module level)
    df_high_tokens = pd.DataFrame(high_token_sentences,
                                  columns=['token_count', 'source_id', 'sentence_type', 'sentence_text'])
    output_file = f'high_token_sentences_over_{threshold}.csv'
    df_high_tokens.to_csv(output_file, index=False, encoding='utf-8-sig')
    print(f"Details saved to: {output_file}")

    return high_token_sentences


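# Optional helper: detailed console statistics. main() keeps the call to this
# function commented out (step 5 below).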
"""打印详细统计信息""" |
|
print("\n=== 详细统计信息 ===") |
|
|
|
all_tokens = sentence1_tokens + [last_sentence2_tokens] if last_sentence2_tokens > 0 else sentence1_tokens |
|
|
|
if not all_tokens: |
|
print("没有数据可统计") |
|
return |
|
|
|
print(f"Sentence1总数: {len(sentence1_tokens)}") |
|
print(f"Last Sentence2: {'已包含' if last_sentence2_tokens > 0 else '无数据'}") |
|
print(f"总句子数: {len(all_tokens)}") |
|
print(f"平均token数: {np.mean(all_tokens):.2f}") |
|
print(f"中位数token数: {np.median(all_tokens):.2f}") |
|
print(f"最小token数: {min(all_tokens)}") |
|
print(f"最大token数: {max(all_tokens)}") |
|
print(f"标准差: {np.std(all_tokens):.2f}") |
|
|
|
# 按区间统计 |
|
print("\n=== 区间分布 ===") |
|
max_tokens = max(all_tokens) |
|
max_range = ((max_tokens // 100) + 1) * 100 |
|
|
|
for i in range(0, max_range, 100): |
|
count = sum(1 for x in all_tokens if i <= x < i + 100) |
|
if count > 0: |
|
percentage = (count / len(all_tokens)) * 100 |
|
print(f"{i}-{i + 99} tokens: {count} 句子 ({percentage:.1f}%)") |
|
|
|
|
|
def main():
    """Main entry point."""
    # Input file path
    input_file = 'segmentation_results_from_7_retried.json'

    # 1. Load the data
    sentence_pairs = load_sentence_pairs(input_file)
    if sentence_pairs is None:
        return

    # 2. Initialize the tokenizer
    tokenizer = initialize_tokenizer("bert-base-chinese")
    if tokenizer is None:
        print("Could not initialize a tokenizer, exiting")
        return

    # 3. Analyze the token distribution
    sentence1_tokens, last_sentence2_tokens, token_details = analyze_token_distribution(sentence_pairs, tokenizer)

    if not sentence1_tokens:
        print("No valid sentence data found")
        return

    # 4. Find sentences with high token counts
    high_token_sentences = find_high_token_sentences(token_details, threshold=300)

    # 5. Print detailed statistics (disabled by default)
    # print_detailed_statistics(sentence1_tokens, last_sentence2_tokens)

    # 6. Create the visualization
    plt = create_token_distribution_chart(sentence1_tokens, last_sentence2_tokens)

    # 7. Save (and try to show) the chart
    try:
        output_file = 'bert_token_distribution.png'
        plt.savefig(output_file, dpi=300, bbox_inches='tight')
        print(f"\nChart saved as: {output_file}")
        plt.show()
    except Exception as e:
        print(f"Error while saving or showing the chart: {e}")
        # Fall back to saving without showing
        try:
            plt.savefig('bert_token_distribution.png', dpi=300, bbox_inches='tight')
            print("Chart saved, but it could not be displayed")
        except Exception as e2:
            print(f"Saving the chart also failed: {e2}")

    print("\nAnalysis complete!")


if __name__ == "__main__":
    main()