You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
344 lines
10 KiB
344 lines
10 KiB
import json
import re
import warnings
from collections import Counter

import matplotlib

# Select the non-interactive Agg backend *before* pyplot is imported, so the
# choice is in effect from the first pyplot use (older Matplotlib releases
# ignore a backend change made after pyplot has been imported).
matplotlib.use('Agg')

import matplotlib.patches as patches  # noqa: E402
import matplotlib.pyplot as plt  # noqa: E402
import numpy as np  # noqa: E402

# Suppress warnings. NOTE: this filter is process-wide and hides every
# warning, not only the ones emitted by Matplotlib.
warnings.filterwarnings('ignore')

# Font fallback list for Chinese text in figures; the first installed font wins.
plt.rcParams['font.sans-serif'] = ['SimHei', 'Arial Unicode MS', 'DejaVu Sans']
# Use the ASCII hyphen for negative numbers instead of the Unicode minus sign.
plt.rcParams['axes.unicode_minus'] = False
|
|
|
|
|
def diagnose_json_file(file_path):
    """
    Read a file and report whether its content parses as JSON.

    The file is decoded as UTF-8. A leading byte-order mark (BOM) is stripped
    while decoding, so that a BOM-prefixed but otherwise valid JSON file is
    not reported as invalid. Progress and findings are printed to stdout.

    Args:
        file_path (str): Path of the JSON file to inspect.

    Returns:
        dict: Always contains a ``"status"`` key, one of:

            * ``"valid"``     - parsed successfully; ``"data"`` holds the value.
            * ``"invalid"``   - not valid JSON; ``"error"`` holds the parser
              message and ``"content"`` the decoded text.
            * ``"empty"``     - empty or whitespace-only file; ``"content"``
              holds the decoded text.
            * ``"not_found"`` - the file does not exist.
            * ``"error"``     - any other failure; ``"error"`` holds the message.
    """
    print(f"正在诊断文件:{file_path}")
    print("=" * 50)

    try:
        # 'utf-8-sig' behaves like 'utf-8' but drops a leading BOM if present;
        # json.loads() rejects text that still starts with one.
        with open(file_path, 'r', encoding='utf-8-sig') as f:
            content = f.read()

        print(f"文件大小:{len(content)} 字符")
        print(f"文件前100个字符:{content[:100]}")
        print(f"文件后100个字符:{content[-100:]}")

        # Whitespace-only files are treated the same as empty ones.
        if not content.strip():
            print("错误:文件为空")
            return {"status": "empty", "content": content}

        try:
            data = json.loads(content)
        except json.JSONDecodeError as e:
            print(f"✗ JSON格式错误:{e}")
            print(f"错误位置:行 {e.lineno}, 列 {e.colno}")
            # The raw text is returned so the caller can attempt a repair.
            return {"status": "invalid", "error": str(e), "content": content}

        print("✓ JSON格式正确")
        return {"status": "valid", "data": data}

    except FileNotFoundError:
        print(f"错误:找不到文件 {file_path}")
        return {"status": "not_found"}
    except Exception as e:
        # Boundary handler: this function reports problems via the returned
        # status instead of raising (permission errors, decoding errors, ...).
        print(f"读取文件时出错:{e}")
        return {"status": "error", "error": str(e)}
|
|
|
|
|
def try_fix_json(content):
    """
    Try to recover records from text that failed to parse as JSON.

    Repair strategies are attempted one at a time, in order, and the first
    one that yields a non-empty list or a dict wins:

        1. JSON Lines (one JSON value per non-blank line).
        2. Missing outer square brackets.
        3. Several objects back to back without separating commas.
        4. Trailing commas before ``}`` or ``]``.
        5. Single quotes used instead of double quotes.
        6. Python literal syntax (e.g. a ``repr()`` dump using ``True``,
           ``False``, ``None`` or mixed quoting), read with
           ``ast.literal_eval``, which evaluates literals only.

    The strategies are text heuristics and are not combined with each other.

    Args:
        content (str): Raw file content.

    Returns:
        list | None: The recovered records (a single dict is wrapped in a
        one-element list), or ``None`` if every strategy failed.
    """
    import ast  # local import: only the last-resort strategy needs it

    print("\n尝试修复JSON格式...")

    fixes = [
        # 1. JSON Lines: parse every non-blank line on its own.
        lambda x: [json.loads(line) for line in x.strip().split('\n') if line.strip()],
        # 2. Outer brackets missing.
        lambda x: json.loads('[' + x + ']'),
        # 3. Adjacent objects without commas between them.
        lambda x: json.loads('[' + re.sub(r'}\s*{', '},{', x) + ']'),
        # 4. Trailing commas.
        lambda x: json.loads(re.sub(r',\s*}', '}', re.sub(r',\s*]', ']', x))),
        # 5. Single quotes instead of double quotes.
        lambda x: json.loads(x.replace("'", '"')),
        # 6. Python literal dump; literal_eval never executes code.
        lambda x: ast.literal_eval(x.strip()),
    ]

    for i, fix_func in enumerate(fixes, 1):
        try:
            print(f"尝试修复方法 {i}...")
            result = fix_func(content)
        except Exception as e:
            # Any failure just means this heuristic does not apply.
            print(f"✗ 修复方法 {i} 失败:{e}")
            continue

        # A Python literal such as "{...}, {...}" evaluates to a tuple.
        if isinstance(result, tuple):
            result = list(result)

        if isinstance(result, list) and result:
            print(f"✓ 修复成功!找到 {len(result)} 条数据")
            return result
        if isinstance(result, dict):
            print("✓ 修复成功!找到 1 条数据")
            return [result]

        # Parsed, but produced an empty list or a scalar: say so instead of
        # silently moving on, then keep trying the remaining strategies.
        print(f"✗ 修复方法 {i} 未得到有效数据")

    print("所有修复方法都失败了")
    return None
|
|
|
|
|
def load_and_analyze_json(file_path):
    """
    Load a JSON file and count how often each label occurs.

    The file is first checked with ``diagnose_json_file``; if it is not valid
    JSON, ``try_fix_json`` is used to attempt a repair. Each record is
    expected to be a dict carrying its label under the key ``'label'`` (or
    ``'Label'``). Records without a usable label are reported on stdout and
    skipped. The statistics are printed as well as returned.

    Args:
        file_path (str): Path of the JSON file.

    Returns:
        tuple: ``(label_counts, total)`` where ``label_counts`` is a
        ``Counter`` mapping label -> occurrences and ``total`` is the number
        of labelled records; ``(None, None)`` if the file could not be
        loaded or contains no labels.
    """
    diagnosis = diagnose_json_file(file_path)
    status = diagnosis["status"]

    if status == "not_found":
        return None, None
    if status == "empty":
        print("文件为空,无法分析")
        return None, None
    if status == "valid":
        data = diagnosis["data"]
    elif status == "invalid":
        data = try_fix_json(diagnosis["content"])
        if data is None:
            print("无法修复JSON格式错误")
            return None, None
    else:
        print(f"未知错误:{diagnosis.get('error', '未知')}")
        return None, None

    # A single top-level object is treated as a one-record data set.
    if not isinstance(data, list):
        data = [data]

    print(f"\n成功加载数据,共 {len(data)} 条记录")

    if not data:
        print("数据为空")
        return None, None

    # Show the shape of the first record as a quick sanity check.
    first_item = data[0]
    print(f"第一条数据结构:{list(first_item.keys()) if isinstance(first_item, dict) else type(first_item)}")

    labels = []
    for i, item in enumerate(data):
        if not isinstance(item, dict):
            print(f"警告:第 {i + 1} 条数据不是字典格式:{item}")
            continue

        if 'label' in item:
            label = item['label']
        elif 'Label' in item:
            label = item['Label']
        else:
            print(f"警告:第 {i + 1} 条数据缺少 'label' 字段:{item}")
            continue

        # Counter keys must be hashable; a list/dict label would otherwise
        # abort the whole analysis with a TypeError.
        try:
            hash(label)
        except TypeError:
            print(f"警告:第 {i + 1} 条数据的标签不可哈希,已跳过:{label!r}")
            continue

        labels.append(label)

    if not labels:
        print("错误:没有找到任何标签数据")
        return None, None

    label_counts = Counter(labels)
    total = len(labels)

    # Labels of mixed types (e.g. 0 and "1", or None) cannot be ordered
    # directly; fall back to ordering by type name, then text form.
    try:
        ordered = sorted(label_counts.items())
    except TypeError:
        ordered = sorted(label_counts.items(),
                         key=lambda kv: (type(kv[0]).__name__, str(kv[0])))

    print("=" * 50)
    print("标签统计结果:")
    print("=" * 50)
    print(f"总数据条数:{total}")
    print("-" * 30)

    for label, count in ordered:
        percentage = (count / total) * 100
        print(f"标签 {label}: {count:4d} 条 ({percentage:5.1f}%)")

    return label_counts, total
|
|
|
|
|
def create_pie_chart(label_counts, total, save_path=None):
    """
    Draw the label distribution as a pie chart next to a bar chart.

    Label ``0`` is captioned "不分段" in the pie chart; every other label is
    captioned "分段". Nothing is drawn when there is no data to plot. The
    figure is always closed before returning, even if plotting fails.

    Args:
        label_counts (dict): Mapping label -> number of occurrences.
        total (int): Total number of labelled records; used for percentages.
        save_path (str, optional): Where to save the image. If omitted, the
            figure is built and closed without being written anywhere.
    """
    # Without data the percentages would divide by zero and the pie would be
    # degenerate, so bail out before touching Matplotlib.
    if not label_counts or total <= 0:
        print("没有可绘制的标签数据,跳过图表生成")
        return

    # Sort once and reuse; labels of mixed types cannot be compared directly,
    # so fall back to ordering by type name, then text form.
    try:
        ordered = sorted(label_counts.items())
    except TypeError:
        ordered = sorted(label_counts.items(),
                         key=lambda kv: (type(kv[0]).__name__, str(kv[0])))

    colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FFEAA7', '#DDA0DD']
    labels = [
        f'不分段 (Label {label})' if label == 0 else f'分段 (Label {label})'
        for label, _ in ordered
    ]
    sizes = [count for _, count in ordered]
    positions = range(len(sizes))

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 7))
    try:
        # Pie chart (left).
        wedges, texts, autotexts = ax1.pie(sizes, labels=labels, autopct='%1.1f%%',
                                           colors=colors[:len(sizes)], startangle=90,
                                           explode=[0.05] * len(sizes))
        ax1.set_title('文本分段标签分布统计', fontsize=16, fontweight='bold', pad=20)

        # Percentage texts drawn inside the wedges.
        for autotext in autotexts:
            autotext.set_color('white')
            autotext.set_fontweight('bold')
            autotext.set_fontsize(12)

        # Wedge captions.
        for text in texts:
            text.set_fontsize(11)

        # Bar chart (right).
        ax2.bar(positions, sizes, color=colors[:len(sizes)], alpha=0.7)
        ax2.set_title('标签数量柱状图', fontsize=16, fontweight='bold', pad=20)
        ax2.set_xlabel('标签类型', fontsize=12)
        ax2.set_ylabel('数量', fontsize=12)
        ax2.set_xticks(positions)
        ax2.set_xticklabels([f'Label {label}' for label, _ in ordered])

        # Count and percentage just above each bar (offset: 1% of the total).
        for i, (label, count) in enumerate(ordered):
            percentage = (count / total) * 100
            ax2.text(i, count + total * 0.01, f'{count}\n({percentage:.1f}%)',
                     ha='center', va='bottom', fontweight='bold')

        plt.tight_layout()

        if save_path:
            try:
                plt.savefig(save_path, dpi=300, bbox_inches='tight')
                print(f"\n图片已保存到:{save_path}")
            except Exception as e:
                # Saving is best-effort; a bad path must not abort the run.
                print(f"保存图片时出错:{e}")

        print("图表生成完成,请查看保存的图片文件。")
    finally:
        # Release the figure even when plotting raised.
        plt.close(fig)
|
|
|
|
|
def create_detailed_report(label_counts, total, file_path):
    """
    Print a text report of the label distribution and save it to disk.

    The report is written next to the input file as
    ``<input name without extension>_analysis_report.txt``. The name is
    derived from the path's extension, so an input that does not end in
    ``.json`` can never cause the report to overwrite the input itself.

    Args:
        label_counts (dict): Mapping label -> number of occurrences.
        total (int): Total number of labelled records; used for percentages.
        file_path (str): Path of the analysed JSON file.
    """
    import datetime  # local imports: only this function needs them
    import os

    # Labels of mixed types cannot be compared directly; fall back to
    # ordering by type name, then text form.
    try:
        ordered = sorted(label_counts.items())
    except TypeError:
        ordered = sorted(label_counts.items(),
                         key=lambda kv: (type(kv[0]).__name__, str(kv[0])))

    report = [
        "=" * 60,
        "文本分段标签分布统计报告",
        "=" * 60,
        f"数据源文件:{file_path}",
        # Local calendar date in ISO format (YYYY-MM-DD).
        f"分析时间:{datetime.date.today()}",
        f"总数据条数:{total}",
        "",
        "标签分布详情:",
        "-" * 40,
    ]

    for label, count in ordered:
        # Guard against a zero total so the report can still be produced.
        percentage = (count / total) * 100 if total else 0.0
        label_desc = "不分段" if label == 0 else "分段"
        report.append(f"Label {label} ({label_desc}):{count:4d} 条 ({percentage:5.1f}%)")

    report.extend([
        "",
        "标签含义说明:",
        "- Label 0:两句话不需要分段,属于同一段落",
        "- Label 1:两句话需要分段,属于不同段落",
    ])

    for line in report:
        print(line)

    # Swap only the real extension. A plain str.replace('.json', ...) returns
    # the path unchanged when '.json' is absent, which would make the report
    # overwrite the source data file.
    root, _ext = os.path.splitext(file_path)
    report_file = root + '_analysis_report.txt'

    try:
        with open(report_file, 'w', encoding='utf-8') as f:
            f.write('\n'.join(report))
        print(f"\n详细报告已保存到:{report_file}")
    except (OSError, UnicodeError) as e:
        # Saving is best-effort; the report has already been printed.
        print(f"保存报告时出错:{e}")
|
|
|
|
|
def main(json_file='test_dataset.json'):
    """
    Run the full analysis: load the data, draw the charts, write the report.

    Args:
        json_file (str): Path of the JSON file to analyse. Defaults to
            ``'test_dataset.json'`` in the current working directory.
    """
    import os  # local import: only needed to derive the image path

    print("JSON文件分析工具 - 增强版")
    print("=" * 50)

    label_counts, total = load_and_analyze_json(json_file)

    if label_counts is None:
        print("分析失败,请检查文件内容和格式。")
        print("\n建议:")
        print("1. 确保文件存在且不为空")
        print("2. 检查JSON格式是否正确")
        print("3. 确保每条数据都有'label'字段")
        print("4. 如果是JSONL格式,确保每行都是有效的JSON对象")
        return

    # Derive the image name from the real extension so that an input not
    # ending in '.json' cannot produce an output path equal to the input.
    root, _ext = os.path.splitext(json_file)
    image_path = root + '_pie_chart.png'
    create_pie_chart(label_counts, total, image_path)

    create_detailed_report(label_counts, total, json_file)

    print("\n分析完成!")
    print("=" * 50)
|
|
|
|
|
# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()