
Initial commit: broadcast content segmentation project

master
Xiang.L 2 months ago
commit
6916d406e2
  1. 581  1.csv时间戳-content/ai_ai_broadcast_info.csv
  2. 1834  1.csv时间戳-content/ai_ai_broadcast_split.csv
  3. 239  1.csv时间戳-content/combined.py
  4. 574  1.csv时间戳-content/merged_results_all.csv
  5. 10070  1.csv时间戳-content/txt-csv/1.带时间戳json/all_timeline_data.json
  6. 181  1.csv时间戳-content/txt-csv/1.带时间戳json/txt_json.py
  7. 486  1.csv时间戳-content/txt-csv/2.合并content方便清洗/API合并/merged_texts619-1103.csv
  8. 206  1.csv时间戳-content/txt-csv/2.合并content方便清洗/API合并/merged_texts619-1103.py
  9. 51  2.content-merged清洗/batch_deduplication_results_619-1103_01.csv
  10. 51  2.content-merged清洗/merged.csv
  11. 408  2.content-merged清洗/merged_results清洗.py
  12. 10  2.content-merged清洗/封装/test.json
  13. 310  2.content-merged清洗/封装/去重.py
  14. 6782  3.merged-结构化json/all_sentence_pairs_for_annotation.json
  15. 542  3.merged-结构化json/batch_deduplication_report_619-1103_01.txt
  16. 51  3.merged-结构化json/batch_deduplication_results_619-1103_01.csv
  17. 155  3.merged-结构化json/json_数据标注.py
  18. 514  4.结构化json-Ai标注/AIbiaozhu.py
  19. 581  4.结构化json-Ai标注/Ai-continue.py
  20. 192  4.结构化json-Ai标注/Data_plaus/Data_plaus.py
  21. 8418  4.结构化json-Ai标注/Data_plaus/cross_document_boundaries.json
  22. 340806  4.结构化json-Ai标注/Data_plaus/enhanced_training_data_with_boundaries.json
  23. 332390  4.结构化json-Ai标注/Data_plaus/segmentation_results_from_7_retried.json
  24. 6782  4.结构化json-Ai标注/all_sentence_pairs_for_annotation.json
  25. 542  4.结构化json-Ai标注/batch_deduplication_report_619-1103_01.txt
  26. 51  4.结构化json-Ai标注/batch_deduplication_results_619-1103_01.csv
  27. 7912  4.结构化json-Ai标注/segmentation_results_from_7.json
  28. 332390  4.结构化json-Ai标注/segmentation_results_from_7_retried.json
  29. 44907  4.结构化json-Ai标注/test_dataset.json
  30. 219  4.结构化json-Ai标注/train and test.py
  31. 287485  4.结构化json-Ai标注/train_dataset.json
  32. 344  4.结构化json-Ai标注/统计/Stat.py
  33. 293  4.结构化json-Ai标注/统计/Token_Stat.py
  34. 6782  4.结构化json-Ai标注/统计/all_sentence_pairs_for_annotation.json
  35. BIN  4.结构化json-Ai标注/统计/bert_token_distribution.png
  36. 11  4.结构化json-Ai标注/统计/high_token_sentences_over_300.csv
  37. 332390  4.结构化json-Ai标注/统计/segmentation_results_from_7_retried.json
  38. BIN  4.结构化json-Ai标注/统计/segmentation_results_from_7_retried_pie_chart.png
  39. 1431  5.AI标注-model_trian/LoRa+NN/失败案例/train-robert-wwm-ext.py
  40. 1188  5.AI标注-model_trian/全参+NN/train-robert-large.py
  41. 1431  5.AI标注-model_trian/全参+NN/train-robert-wwm-ext-new.py
  42. 1431  5.AI标注-model_trian/全参+NN/train-robert-wwm-ext.py
  43. 1686  5.AI标注-model_trian/全参微调/FreeLB扰动训练/Bert-train_FreeLB.py
  44. 957  5.AI标注-model_trian/全参微调/无验证集训练/Bert-train.py
  45. 990  5.AI标注-model_trian/全参微调/无验证集训练/train-continue.py
  46. 1208  5.AI标注-model_trian/全参微调/无验证集训练/模型按验证集准确率选择/Bert-train-plaus.py
  47. 1165  5.AI标注-model_trian/全参微调/有验证集训练/Bert-test_eval.py
  48. 1169  5.AI标注-model_trian/全参微调/有验证集训练/Bert-testeval_continue.py
  49. 1401  5.AI标注-model_trian/全参微调/有验证集训练/模型按验证集准确率选择/Bert-test_evalplaus-continue.py
  50. 1296  5.AI标注-model_trian/全参微调/有验证集训练/模型按验证集准确率选择/Bert-test_evalplaus.py
  51. 656  5.AI标注-model_trian/封装模型/Fastapi.py
  52. 996  5.AI标注-model_trian/封装模型/Project1/app.py
  53. 656  5.AI标注-model_trian/无标点符号训练/Fastapi.py
  54. 1431  5.AI标注-model_trian/无标点符号训练/train-robert-wwm-ext.py
  55. 67  6.model_train-test/API/API-test.py
  56. 664  6.model_train-test/API/Fastapi.py
  57. 247  6.model_train-test/API/test-合并.py
  58. 10  6.model_train-test/API/test.json
  59. 658  6.model_train-test/Project.py
  60. BIN  6.model_train-test/test/class_performance_analysis.png
  61. BIN  6.model_train-test/test/confusion_matrix_test_results.png
  62. 52208  6.model_train-test/test/detailed_predictions.json
  63. 70  6.model_train-test/test/test_results_detailed.json
  64. 41  6.model_train-test/test/test_summary.json
  65. 1059  FinalData/batch_deduplication_results_8-1103.csv
  66. 51  FinalData/batch_deduplication_results_new50.csv
  67. 51  FinalData/merged-new50.csv
  68. 324480  FinalData/segmentation_results_from_7-1103_retried.json
  69. 332390  FinalData/segmentation_results_from_7-1153_retried.json

581
1.csv时间戳-content/ai_ai_broadcast_info.csv

File diff suppressed because one or more lines are too long

1834
1.csv时间戳-content/ai_ai_broadcast_split.csv

File diff suppressed because one or more lines are too long

239
1.csv时间戳-content/combined.py

@@ -0,0 +1,239 @@
import requests
import json
import pandas as pd
import csv
from typing import List, Dict
import time
# %%
# 读取CSV文件
csv_file_path = "ai_ai_broadcast_info.csv"
# 尝试不同的编码格式
encodings = ['utf-8', 'gbk', 'gb2312', 'utf-8-sig', 'latin-1', 'cp1252']
df = None
for encoding in encodings:
try:
print(f"尝试使用 {encoding} 编码读取文件...")
df = pd.read_csv(csv_file_path, encoding=encoding)
print(f"成功使用 {encoding} 编码读取CSV文件,共 {len(df)} 行数据")
print(f"列名:{list(df.columns)}")
break
except UnicodeDecodeError as e:
print(f" {encoding} 编码失败:{str(e)}")
continue
except Exception as e:
print(f" {encoding} 编码读取时出现其他错误:{str(e)}")
continue
if df is None:
print("错误:尝试了所有编码格式都无法读取文件")
exit()
# 检查是否有id和content列
if 'id' not in df.columns:
print("错误:CSV文件中没有找到'id'")
print(f"可用列:{list(df.columns)}")
exit()
elif 'content' not in df.columns:
print("错误:CSV文件中没有找到'content'")
print(f"可用列:{list(df.columns)}")
exit()
print(f"数据加载完成,可用的ID值:{sorted(df['id'].unique())}")
# %%
class SimpleOpenAIHubClient:
def __init__(self, api_key):
self.api_key = api_key
self.base_url = "https://api.openai-hub.com"
self.headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json"
}
def chat(self, prompt, model="gpt-4.1"):
"""发送prompt并返回模型回答"""
payload = {
"model": model,
"messages": [
{
"role": "user",
"content": prompt
}
],
"max_tokens": 100000,
"temperature": 0.7
}
try:
response = requests.post(
f"{self.base_url}/v1/chat/completions",
headers=self.headers,
json=payload,
timeout=60
)
if response.status_code == 200:
result = response.json()
return result['choices'][0]['message']['content']
else:
return f"错误: {response.status_code} - {response.text}"
except requests.exceptions.RequestException as e:
return f"请求异常: {str(e)}"
print("AI客户端类定义完成!")
# %%
# 设置API Key
API_KEY = "sk-XREp2jnIXyZ6UoCnzZeO0ahmLi9OEXuVAtFLojKFpG9gCZ4e" # 请替换为你的实际API Key
# 初始化AI客户端
client = SimpleOpenAIHubClient(API_KEY)
print("AI模型加载完成!")
# %%
prompt_template = """任务:文本合并
要求
1. 提取JSON数组中所有"text"字段的内容
2. 按原始顺序直接拼接
3. 不修改任何文字
4. 不添加标点符号
5. 不做错误纠正
6. 只输出合并后的文本
输入数据
<audio_text> %s </audio_text>
输出"""
print("Prompt模板设置完成!")
# %%
# 批量处理ID 8-617的数据
target_ids = list(range(8, 618))  # 8到617(包含617)
results = []
print(f"开始批量处理ID {target_ids[0]}{target_ids[-1]} 的数据...")
print(f"目标ID列表:{target_ids}")
# 统计可用的ID
available_ids = df['id'].unique()
missing_ids = [id for id in target_ids if id not in available_ids]
processable_ids = [id for id in target_ids if id in available_ids]
if missing_ids:
print(f"警告:以下ID在数据中不存在:{missing_ids}")
print(f"将要处理的ID:{processable_ids}")
# 循环处理每个ID
for current_id in target_ids:
print(f"\n--- 处理ID {current_id} ---")
# 查找对应ID的行
target_row = df[df['id'] == current_id]
if len(target_row) == 0:
print(f"警告:没有找到id={current_id}的数据行")
results.append({
'id': current_id,
'original_content': "",
'merged_text': "",
'status': "数据不存在"
})
continue
# 获取content内容
target_content = target_row['content'].iloc[0]
content = str(target_content) if pd.notna(target_content) else ""
print(f"Content内容长度:{len(content)}")
print(f"Content预览:{content[:100]}..." if content else "Content为空")
if not content or content == 'nan':
print("Content为空,无法处理")
results.append({
'id': current_id,
'original_content': content,
'merged_text': "",
'status': "内容为空"
})
continue
# 构建prompt
prompt = prompt_template % content
print("正在调用AI模型处理...")
# 调用AI模型
try:
ai_response = client.chat(prompt)
results.append({
'id': current_id,
'original_content': content,
'merged_text': ai_response,
'status': "成功"
})
print("处理完成!")
print(f"合并后文本预览:{ai_response[:100]}...")
# 添加延时,避免API调用过于频繁
time.sleep(1)
except Exception as e:
print(f"处理出错:{str(e)}")
results.append({
'id': current_id,
'original_content': content,
'merged_text': "",
'status': f"处理出错: {str(e)}"
})
print(f"\n=== 批量处理完成!===")
print(f"总共尝试处理:{len(target_ids)} 条记录")
print(f"实际处理完成:{len(results)} 条记录")
# %%
# 生成处理结果报告
result_df = pd.DataFrame(results)
# 显示处理结果统计
print("\n处理结果统计:")
status_counts = result_df['status'].value_counts()
print(status_counts)
# 显示各个状态的ID
for status in status_counts.index:
status_ids = result_df[result_df['status'] == status]['id'].tolist()
print(f"{status}的ID:{status_ids}")
# 显示结果预览
print("\n结果预览(前5行):")
preview_df = result_df[['id', 'status', 'merged_text']].copy()
preview_df['merged_text_preview'] = preview_df['merged_text'].apply(
lambda x: str(x)[:100] + "..." if len(str(x)) > 100 else str(x)
)
print(preview_df[['id', 'status', 'merged_text_preview']].head())
# 保存到CSV文件
output_file = "merged_results_all.csv"
result_df.to_csv(output_file, index=False, encoding='utf-8-sig')
print(f"\n结果已保存到文件:{output_file}")
# 成功处理的统计
successful_results = [r for r in results if r['status'] == '成功']
print(f"\n最终统计:")
print(f"成功处理:{len(successful_results)} 条记录")
print(f"失败处理:{len(results) - len(successful_results)} 条记录")
if successful_results:
successful_ids = [r['id'] for r in successful_results]
print(f"成功处理的ID:{successful_ids}")

574
1.csv时间戳-content/merged_results_all.csv

File diff suppressed because one or more lines are too long

10070
1.csv时间戳-content/txt-csv/1.带时间戳json/all_timeline_data.json

File diff suppressed because it is too large

181
1.csv时间戳-content/txt-csv/1.带时间戳json/txt_json.py

@@ -0,0 +1,181 @@
import json
import re
import os
from datetime import datetime
def time_to_milliseconds(time_str):
"""将时间字符串转换为毫秒"""
# 解析时间格式 HH:MM:SS
parts = time_str.split(':')
hours = int(parts[0])
minutes = int(parts[1])
seconds = int(parts[2])
# 转换为毫秒
total_ms = (hours * 3600 + minutes * 60 + seconds) * 1000
return total_ms
def parse_timeline_file(file_path, fixed_id=1104):
"""解析时间轴文本文件"""
result = []
try:
with open(file_path, 'r', encoding='utf-8') as file:
content = file.read().strip()
# 按行分割内容
lines = content.split('\n')
i = 0
while i < len(lines):
line = lines[i].strip()
# 检查是否是时间轴格式:HH:MM:SS-HH:MM:SS
time_match = re.match(r'(\d{2}:\d{2}:\d{2})-(\d{2}:\d{2}:\d{2})', line)
if time_match:
start_time_str = time_match.group(1)
end_time_str = time_match.group(2)
start_time_ms = time_to_milliseconds(start_time_str)
end_time_ms = time_to_milliseconds(end_time_str)
# 获取下一行作为内容(如果存在)
content_text = ""
if i + 1 < len(lines):
content_text = lines[i + 1].strip()
i += 1 # 跳过内容行
# 创建JSON对象
if content_text: # 只有当内容不为空时才添加
json_obj = {
"d_id": fixed_id,
"start_time": start_time_ms,
"end_time": end_time_ms,
"content": content_text
}
result.append(json_obj)
i += 1
except FileNotFoundError:
print(f"文件未找到: {file_path}")
return []
except Exception as e:
print(f"处理文件时出错: {e}")
return []
return result
def get_txt_files(folder_path):
"""获取文件夹中所有的txt文件"""
txt_files = []
try:
for filename in os.listdir(folder_path):
if filename.lower().endswith('.txt'):
full_path = os.path.join(folder_path, filename)
txt_files.append((filename, full_path))
# 按文件名排序,确保处理顺序一致
txt_files.sort(key=lambda x: x[0])
return txt_files
except Exception as e:
print(f"读取文件夹时出错: {e}")
return []
def save_to_json(data, output_path):
"""保存为JSON文件"""
try:
with open(output_path, 'w', encoding='utf-8') as file:
json.dump(data, file, ensure_ascii=False, indent=2)
print(f"JSON文件已保存: {output_path}")
except Exception as e:
print(f"保存JSON文件时出错: {e}")
def batch_process_txt_files(folder_path, start_id=1104):
"""批量处理文件夹中的txt文件"""
txt_files = get_txt_files(folder_path)
if not txt_files:
print("未找到任何txt文件")
return
print(f"找到 {len(txt_files)} 个txt文件:")
for i, (filename, _) in enumerate(txt_files):
print(f"{i + 1}. {filename} (d_id: {start_id + i})")
all_data = []
file_summary = []
for i, (filename, file_path) in enumerate(txt_files):
current_id = start_id + i
print(f"\n正在处理: {filename} (d_id: {current_id})")
# 解析单个文件
file_data = parse_timeline_file(file_path, current_id)
if file_data:
all_data.extend(file_data)
file_summary.append({
"filename": filename,
"d_id": current_id,
"segments": len(file_data)
})
print(f"成功解析 {len(file_data)} 个数据段")
else:
print(f"文件 {filename} 未能解析到有效数据")
# 保存合并的JSON文件
if all_data:
output_file = os.path.join(folder_path, "all_timeline_data.json")
save_to_json(all_data, output_file)
# 保存处理摘要
summary_file = os.path.join(folder_path, "processing_summary.json")
summary_data = {
"total_files": len(txt_files),
"total_segments": len(all_data),
"start_id": start_id,
"end_id": start_id + len(txt_files) - 1,
"files": file_summary
}
save_to_json(summary_data, summary_file)
print(f"\n=== 处理完成 ===")
print(f"总文件数: {len(txt_files)}")
print(f"总数据段: {len(all_data)}")
print(f"ID范围: {start_id} - {start_id + len(txt_files) - 1}")
print(f"合并文件: all_timeline_data.json")
print(f"摘要文件: processing_summary.json")
# # 分别保存每个文件的JSON
# print(f"\n正在保存单独的JSON文件...")
# for i, (filename, file_path) in enumerate(txt_files):
# current_id = start_id + i
# file_data = parse_timeline_file(file_path, current_id)
# if file_data:
# json_filename = filename.replace('.txt', '.json')
# json_path = os.path.join(folder_path, json_filename)
# save_to_json(file_data, json_path)
#
# print("所有单独的JSON文件已保存完成")
else:
print("没有解析到任何有效数据")
def main():
# 批量处理文件夹中的txt文件
folder_path = r"D:\workstation\voice-txt\ct-punc-test\ASR+punc\staic-应急宣传"
start_id = 1104
print("开始批量处理txt文件...")
batch_process_txt_files(folder_path, start_id)
if __name__ == "__main__":
main()
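For reference, a short sketch (the sample time range and sentence are hypothetical) of the input/output contract that parse_timeline_file assumes:
# A timeline .txt file alternates a time-range line and a content line, e.g.
#   00:00:05-00:00:12
#   各位听众大家好
# time_to_milliseconds("00:00:05") == 5000 and time_to_milliseconds("00:00:12") == 12000,
# so with fixed_id=1104 the parsed record would be:
#   {"d_id": 1104, "start_time": 5000, "end_time": 12000, "content": "各位听众大家好"}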

486
1.csv时间戳-content/txt-csv/2.合并content方便清洗/API合并/merged_texts619-1103.csv

File diff suppressed because one or more lines are too long

206
1.csv时间戳-content/txt-csv/2.合并content方便清洗/API合并/merged_texts619-1103.py

@@ -0,0 +1,206 @@
import json
import requests
import csv
from collections import defaultdict
class SimpleOpenAIHubClient:
def __init__(self, api_key):
self.api_key = api_key
self.base_url = "https://api.openai-hub.com"
self.headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json"
}
def chat(self, prompt, model="gpt-4.1"):
"""发送prompt并返回模型回答"""
payload = {
"model": model,
"messages": [
{
"role": "user",
"content": prompt
}
],
"max_tokens": 100000,
"temperature": 0.7
}
try:
response = requests.post(
f"{self.base_url}/v1/chat/completions",
headers=self.headers,
json=payload,
timeout=60
)
if response.status_code == 200:
result = response.json()
return result['choices'][0]['message']['content']
else:
return f"错误: {response.status_code} - {response.text}"
except requests.exceptions.RequestException as e:
return f"请求异常: {str(e)}"
def load_json_data(file_path):
"""加载JSON数据"""
try:
with open(file_path, 'r', encoding='utf-8') as file:
data = json.load(file)
print(f"成功加载数据,共 {len(data)} 条记录")
return data
except Exception as e:
print(f"加载JSON文件时出错: {e}")
return []
def group_by_id(data):
"""按d_id分组数据"""
grouped = defaultdict(list)
for item in data:
d_id = item.get('d_id')
content = item.get('content', '')
if d_id is not None and content:
grouped[d_id].append({
'start_time': item.get('start_time'),
'end_time': item.get('end_time'),
'content': content
})
# 对每个组内的数据按start_time排序
for d_id in grouped:
grouped[d_id].sort(key=lambda x: x['start_time'])
return grouped
def create_prompt_for_group(group_data):
"""为每个组创建AI prompt"""
# 构建类似JSON的文本数组格式
text_array = []
for i, item in enumerate(group_data):
# 转义双引号
escaped_content = item["content"].replace('"', '""')
text_array.append(f'{{"text": "{escaped_content}"}}')
json_like_text = "[" + ", ".join(text_array) + "]"
prompt_template = """任务:文本合并
要求
1. 提取JSON数组中所有"text"字段的内容
2. 按原始顺序直接拼接
3. 不修改任何文字
4. 不添加标点符号
5. 不做错误纠正
6. 只输出合并后的文本
输入数据
<audio_text> %s </audio_text>
输出"""
return prompt_template % json_like_text
def merge_texts_directly(grouped_data):
"""直接按时间顺序合并文本,不使用AI模型"""
merged_results = {}
total_groups = len(grouped_data)
print(f"开始处理 {total_groups} 个组...")
for i, (d_id, group_data) in enumerate(grouped_data.items(), 1):
print(f"处理组 {i}/{total_groups}: d_id={d_id}, 包含{len(group_data)}个片段")
# 直接拼接所有文本内容
merged_text = ""
for item in group_data:
merged_text += item['content']
# 存储结果
merged_results[d_id] = {
'original_content': json.dumps(
[{"end": item['end_time'], "start": item['start_time'], "text": item['content']} for item in
group_data], ensure_ascii=False),
'merged_text': merged_text,
'segments_count': len(group_data)
}
print(f"完成: d_id={d_id}")
print(f"合并结果预览: {merged_text[:100]}...")
print("-" * 50)
return merged_results
def save_to_csv(merged_results, output_path):
"""保存为CSV文件"""
try:
with open(output_path, 'w', newline='', encoding='utf-8') as csvfile:
fieldnames = ['id', 'original_content', 'merged_text', 'status']
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
# 写入表头
writer.writeheader()
# 按id排序并写入数据
for d_id in sorted(merged_results.keys()):
result = merged_results[d_id]
writer.writerow({
'id': d_id,
'original_content': result['original_content'],
'merged_text': result['merged_text'],
'status': '成功'
})
print(f"CSV文件已保存到: {output_path}")
return True
except Exception as e:
print(f"保存CSV文件时出错: {e}")
return False
def main():
# 输入文件路径
input_file = r"D:\workstation\Data\广播\all_data619-1103.json"
# 输出文件路径
output_file = r"D:\workstation\Data\广播\merged_texts619-1103.csv"
# 加载数据
print("加载JSON数据...")
data = load_json_data(input_file)
if not data:
print("无数据可处理")
return
# 按d_id分组
print("按d_id分组数据...")
grouped_data = group_by_id(data)
print(f"共分为 {len(grouped_data)} 个组")
# 显示分组统计
print("\n分组统计:")
for d_id, group_data in list(grouped_data.items())[:5]: # 显示前5个组的信息
print(f"d_id {d_id}: {len(group_data)} 个片段")
if len(grouped_data) > 5:
print(f"... 还有 {len(grouped_data) - 5} 个组")
# 直接合并文本(不使用AI)
print("\n开始直接文本合并...")
merged_results = merge_texts_directly(grouped_data)
# 保存结果
print("\n保存合并结果...")
if save_to_csv(merged_results, output_file):
print("✅ 所有任务完成!")
print(f"📁 输入文件: {input_file}")
print(f"📄 输出文件: {output_file}")
print(f"📊 处理统计: {len(grouped_data)} 个组,{len(data)} 个原始片段")
else:
print("❌ 保存结果失败")
if __name__ == "__main__":
main()
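A short usage sketch (in-memory sample data with invented values) of how group_by_id and merge_texts_directly behave for two segments sharing one d_id, assuming the functions above are available:
sample = [
    {"d_id": 619, "start_time": 2000, "end_time": 4000, "content": "天气"},
    {"d_id": 619, "start_time": 0, "end_time": 2000, "content": "今天"},
]
grouped = group_by_id(sample)            # segments are sorted by start_time within d_id 619
merged = merge_texts_directly(grouped)   # merged[619]["merged_text"] == "今天天气", segments_count == 2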

51
2.content-merged清洗/batch_deduplication_results_619-1103_01.csv

File diff suppressed because one or more lines are too long

51
2.content-merged清洗/merged.csv

File diff suppressed because one or more lines are too long

408
2.content-merged清洗/merged_results清洗.py

@@ -0,0 +1,408 @@
import pandas as pd
import re
from difflib import SequenceMatcher
from collections import Counter
import chardet
def detect_file_encoding(file_path):
"""检测文件编码"""
with open(file_path, 'rb') as f:
raw_data = f.read(10000) # 读取前10KB来检测编码
result = chardet.detect(raw_data)
return result['encoding']
def safe_read_csv(file_path):
"""安全读取CSV文件,自动检测编码"""
# 尝试多种编码方式
encodings = ['utf-8', 'gbk', 'gb2312', 'utf-8-sig', 'latin1']
# 首先尝试自动检测编码
try:
detected_encoding = detect_file_encoding(file_path)
if detected_encoding:
encodings.insert(0, detected_encoding)
print(f"检测到文件编码: {detected_encoding}")
except:
print("编码检测失败,使用默认编码列表")
# 尝试不同编码读取文件
for encoding in encodings:
try:
print(f"尝试使用编码 {encoding} 读取文件...")
df = pd.read_csv(file_path, encoding=encoding)
print(f"成功使用编码 {encoding} 读取文件")
return df
except UnicodeDecodeError:
print(f"编码 {encoding} 失败")
continue
except Exception as e:
print(f"使用编码 {encoding} 时出现其他错误: {e}")
continue
# 如果所有编码都失败,尝试忽略错误的方式
try:
print("尝试使用 utf-8 编码并忽略错误...")
df = pd.read_csv(file_path, encoding='utf-8', encoding_errors='ignore')
print("成功读取文件(忽略了一些字符)")
return df
except Exception as e:
raise Exception(f"无法读取文件 {file_path}: {e}")
def clean_text(text):
# 统一换行符和空格
text = re.sub(r'\r\n|\r|\n', ' ', text)
text = re.sub(r'\s+', ' ', text) # 多个空格合并为一个
# 去除HTML标签(如果存在)
text = re.sub(r'<[^>]+>', '', text)
# 【修改点】保留中文、英文、数字、标点符号 (增加了顿号 `、`)
text = re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9、,%.。!?;:""''()【】\-\s]', '', text)
# 标点符号规范化 (初次)
punctuation_map = {
',,': '',
'..': '',
',。': '',
',。': '',
'!!': '',
'??': '',
';;': ''
}
for old, new in punctuation_map.items():
text = text.replace(old, new)
return text.strip()
def remove_paragraph_duplicates(text, similarity_threshold=0.85):
"""
段落级别去重基于相似度去除重复段落
"""
paragraphs = text.split('。')
paragraphs = [p.strip() for p in paragraphs if p.strip() and len(p) > 0]
unique_paragraphs = []
removed_paragraphs = []
for paragraph in paragraphs:
is_similar = False
for existing in unique_paragraphs:
similarity = SequenceMatcher(None, paragraph, existing).ratio()
if similarity > similarity_threshold:
is_similar = True
if len(paragraph) > len(existing):
removed_paragraphs.append(f"段落替换: {existing[:50]}...")
unique_paragraphs[unique_paragraphs.index(existing)] = paragraph
else:
removed_paragraphs.append(f"段落重复: {paragraph[:50]}...")
break
if not is_similar:
unique_paragraphs.append(paragraph)
return '。'.join(unique_paragraphs), removed_paragraphs
def remove_sentence_duplicates(text, similarity_threshold=0.9):
"""
句子级别去重去除重复的句子
"""
# 句子切分时,也可以考虑加入顿号,但这可能会切分得过细,这里暂时不修改
sentences = re.split(r'[。!?;]', text)
sentences = [s.strip() for s in sentences if s.strip() and len(s) > 0]
unique_sentences = []
removed_sentences = []
for sentence in sentences:
is_duplicate = False
for existing in unique_sentences:
similarity = SequenceMatcher(None, sentence, existing).ratio()
if similarity > similarity_threshold:
is_duplicate = True
if len(sentence) > len(existing):
removed_sentences.append(f"句子替换: {existing[:30]}...")
unique_sentences[unique_sentences.index(existing)] = sentence
else:
removed_sentences.append(f"句子重复: {sentence[:30]}...")
break
if not is_duplicate:
unique_sentences.append(sentence)
result = []
for sentence in unique_sentences:
if sentence:
if any(word in sentence for word in ['', '提醒', '注意', '防止']):
result.append(sentence + '')
elif '' in sentence or sentence.endswith('') or sentence.endswith(''):
result.append(sentence + '')
elif any(word in sentence for word in ['', '重要', '紧急', '警告']):
result.append(sentence + '')
else:
result.append(sentence + '')
return ''.join(result), removed_sentences
def remove_phrase_duplicates(text, min_phrase_length=4, max_phrase_length=20):
"""
短语级别去重去除重复的短语和词组
"""
words = re.findall(r'[\u4e00-\u9fa5a-zA-Z0-9]+', text)
phrases = []
for n in range(min_phrase_length, min(max_phrase_length + 1, len(words) + 1)):
for i in range(len(words) - n + 1):
phrase = ''.join(words[i:i + n])
if len(phrase) >= min_phrase_length:
phrases.append(phrase)
phrase_counts = Counter(phrases)
frequent_phrases = [(phrase, count) for phrase, count in phrase_counts.items()
if count >= 3 and len(phrase) >= 6]
cleaned_text = text
removed_phrases = []
for phrase, count in sorted(frequent_phrases, key=lambda x: len(x[0]), reverse=True):
if phrase in cleaned_text:
first_occurrence = cleaned_text.find(phrase)
remaining_text = cleaned_text[first_occurrence + len(phrase):]
removed_count = remaining_text.count(phrase)
if removed_count > 0:
cleaned_text = cleaned_text[:first_occurrence + len(phrase)] + remaining_text.replace(phrase, '')
removed_phrases.append(f"短语重复({removed_count}次): {phrase}")
return cleaned_text, removed_phrases
def comprehensive_deduplication(text):
"""
综合去重按层级顺序进行多级别去重并在最后进行标点规范化
"""
original_length = len(text)
# 1. 段落级别去重
print("1. 执行段落级别去重...")
text, paragraph_removed = remove_paragraph_duplicates(text, 0.85)
paragraph_length = len(text)
print(f" 段落去重后长度: {paragraph_length} (减少 {original_length - paragraph_length} 字符)")
# 2. 句子级别去重
print("2. 执行句子级别去重...")
text, sentence_removed = remove_sentence_duplicates(text, 0.9)
sentence_length = len(text)
print(f" 句子去重后长度: {sentence_length} (减少 {paragraph_length - sentence_length} 字符)")
# 3. 短语级别去重
print("3. 执行短语级别去重...")
text, phrase_removed = remove_phrase_duplicates(text, 4, 15)
phrase_length = len(text)
print(f" 短语去重后长度: {phrase_length} (减少 {sentence_length - phrase_length} 字符)")
# 4. 最终标点符号规范化
print("4. 执行最终标点符号规范化...")
punctuation_map = {
',,': '',
'..': '',
',。': '',
',。': '',
'!!': '',
'??': '',
';;': ''
}
final_text = text
for old, new in punctuation_map.items():
final_text = final_text.replace(old, new)
final_length = len(final_text)
print(f" 最终规范化后长度: {final_length} (减少 {phrase_length - final_length} 字符)")
# 生成详细报告
report = {
'original_length': original_length,
'after_paragraph': paragraph_length,
'after_sentence': sentence_length,
'after_phrase': phrase_length,
'final_length': final_length,
'total_reduction': original_length - final_length,
'reduction_ratio': (original_length - final_length) / original_length if original_length > 0 else 0,
'removed_items': {
'paragraphs': paragraph_removed,
'sentences': sentence_removed,
'phrases': phrase_removed
}
}
return final_text, report
# 主处理流程
def main():
print("开始多级别去重处理...\n")
# 读取CSV文件
try:
df = safe_read_csv('merged.csv')
except Exception as e:
print(f"读取CSV文件失败: {e}")
return
print(f"读取到CSV文件,共 {len(df)} 行数据")
print(f"CSV文件列名: {list(df.columns)}")
if 'id' in df.columns:
print(f"可用的ID列表: {sorted(df['id'].unique())}")
else:
print("警告:CSV文件中没有找到'id'")
print("请检查CSV文件格式")
return
# 准备结果列表
all_results = []
all_reports = []
# 遍历所有ID
for current_id in sorted(df['id'].unique()):
print(f"\n{'=' * 50}")
print(f"处理ID: {current_id}")
print(f"{'=' * 50}")
target_row = df[df['id'] == current_id]
if len(target_row) == 0:
print(f"警告:没有找到ID={current_id}的数据")
continue
if 'merged_text' not in target_row.columns:
print(f"错误:找不到merged_text列")
continue
original_text = target_row['merged_text'].iloc[0]
if pd.isna(original_text) or str(original_text).strip() == '':
print(f"警告:ID={current_id}的merged_text为空,跳过处理")
all_results.append({
'id': current_id, 'original_text': '', 'cleaned_text': '', 'final_processed_text': '',
'original_length': 0, 'cleaned_length': 0, 'final_length': 0, 'paragraph_reduction': 0,
'sentence_reduction': 0, 'phrase_reduction': 0, 'punctuation_reduction': 0,
'total_reduction': 0, 'reduction_ratio': 0
})
continue
print(f"原始文本长度: {len(original_text)} 字符")
try:
print("执行基础文本清洗...")
cleaned_text = clean_text(str(original_text))
print(f"清洗后文本长度: {len(cleaned_text)} 字符")
final_text, dedup_report = comprehensive_deduplication(cleaned_text)
print(f"处理完成")
print(f"总体压缩比: {dedup_report['reduction_ratio']:.2%}")
print(f"最终文本长度: {dedup_report['final_length']} 字符")
result_record = {
'id': current_id,
'original_text': original_text,
'cleaned_text': cleaned_text,
'final_processed_text': final_text,
'original_length': len(str(original_text)),
'cleaned_length': len(cleaned_text),
'final_length': len(final_text),
'paragraph_reduction': dedup_report['original_length'] - dedup_report['after_paragraph'],
'sentence_reduction': dedup_report['after_paragraph'] - dedup_report['after_sentence'],
'phrase_reduction': dedup_report['after_sentence'] - dedup_report['after_phrase'],
'punctuation_reduction': dedup_report['after_phrase'] - dedup_report['final_length'],
'total_reduction': dedup_report['total_reduction'],
'reduction_ratio': dedup_report['reduction_ratio']
}
all_results.append(result_record)
all_reports.append((current_id, dedup_report))
except Exception as e:
print(f"处理ID={current_id}时出错: {str(e)}")
all_results.append({
'id': current_id, 'original_text': str(original_text), 'cleaned_text': '', 'final_processed_text': '',
'original_length': len(str(original_text)), 'cleaned_length': 0, 'final_length': 0,
'paragraph_reduction': 0, 'sentence_reduction': 0, 'phrase_reduction': 0, 'punctuation_reduction': 0,
'total_reduction': 0, 'reduction_ratio': 0
})
print(f"\n{'=' * 60}")
print("所有ID处理完成!")
print(f"{'=' * 60}")
result_df = pd.DataFrame(all_results)
print(f"总共处理: {len(all_results)} 个ID")
print(f"成功处理: {len([r for r in all_results if r['final_length'] > 0])} 个ID")
print(f"处理失败或跳过: {len([r for r in all_results if r['final_length'] == 0])} 个ID")
if len(all_results) > 0:
avg_reduction = result_df['reduction_ratio'].mean()
print(f"平均压缩比: {avg_reduction:.2%}")
print(f"总原始字符数: {result_df['original_length'].sum()}")
print(f"总最终字符数: {result_df['final_length'].sum()}")
try:
result_df.to_csv('batch_deduplication_results_619-1103_01.csv', index=False, encoding='utf-8-sig')
print("结果已保存到: batch_deduplication_results_619-1103_01.csv")
except Exception as e:
print(f"保存结果CSV时出错: {e}")
try:
with open('batch_deduplication_report_619-1103_01.txt', 'w', encoding='utf-8') as f:
f.write("=== 批量多级别去重详细报告 ===\n\n")
f.write(f"处理日期: {pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
f.write(f"总共处理: {len(all_results)} 个ID\n")
f.write(f"成功处理: {len([r for r in all_results if r['final_length'] > 0])} 个ID\n\n")
if len(all_results) > 0:
f.write("总体统计:\n")
f.write(f"- 平均压缩比: {result_df['reduction_ratio'].mean():.2%}\n")
f.write(f"- 总原始字符数: {result_df['original_length'].sum():,}\n")
f.write(f"- 总最终字符数: {result_df['final_length'].sum():,}\n")
f.write(f"- 总减少字符数: {result_df['total_reduction'].sum():,}\n\n")
for id_num, report in all_reports:
f.write(f"\n--- ID {id_num} 详细报告 ---\n")
f.write(f"原始文本长度: {report['original_length']} 字符\n")
f.write(f"最终文本长度: {report['final_length']} 字符\n")
f.write(f"总体压缩比: {report['reduction_ratio']:.2%}\n")
f.write("各级别处理效果:\n")
f.write(f"1. 段落级去重: 减少 {report['original_length'] - report['after_paragraph']} 字符\n")
f.write(f"2. 句子级去重: 减少 {report['after_paragraph'] - report['after_sentence']} 字符\n")
f.write(f"3. 短语级去重: 减少 {report['after_sentence'] - report['after_phrase']} 字符\n")
f.write(f"4. 最终标点规范化: 减少 {report['after_phrase'] - report['final_length']} 字符\n")
for level, items in report['removed_items'].items():
if items:
f.write(f"{level.upper()}级别移除了 {len(items)} 项内容\n")
print("详细报告已保存到: batch_deduplication_report_619-1103.txt")
except Exception as e:
print(f"保存报告时出错: {e}")
print(f"\n结果预览:")
print(result_df[['id', 'original_length', 'final_length', 'reduction_ratio']].head(10))
if __name__ == "__main__":
main()
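To illustrate the similarity thresholds used above, a small sketch (the two example strings are invented; the exact ratio depends on the input):
from difflib import SequenceMatcher
a = "森林防火期内禁止一切野外用火"
b = "森林防火期内,禁止一切野外用火"
print(SequenceMatcher(None, a, b).ratio())  # well above the 0.9 sentence-level threshold, so only one of the two is kept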

10
2.content-merged清洗/封装/test.json

File diff suppressed because one or more lines are too long

310
2.content-merged清洗/封装/去重.py

@@ -0,0 +1,310 @@
import json
import re
import chardet
from difflib import SequenceMatcher
from collections import Counter
from typing import Union, List, Dict, Any
import os
class BroadcastDeduplicator:
"""广播去重处理类"""
def __init__(self):
pass
def detect_file_encoding(self, file_path: str) -> str:
with open(file_path, 'rb') as f:
raw_data = f.read(10000)
result = chardet.detect(raw_data)
return result['encoding']
def safe_read_json(self, file_path: str) -> Union[Dict, List]:
encodings = ['utf-8', 'gbk', 'gb2312', 'utf-8-sig', 'latin1']
try:
detected_encoding = self.detect_file_encoding(file_path)
if detected_encoding:
encodings.insert(0, detected_encoding)
print(f"检测到文件编码: {detected_encoding}")
except:
print("编码检测失败,使用默认编码列表")
for encoding in encodings:
try:
print(f"尝试使用编码 {encoding} 读取文件...")
with open(file_path, 'r', encoding=encoding) as f:
data = json.load(f)
print(f"成功使用编码 {encoding} 读取文件")
return data
except UnicodeDecodeError:
print(f"编码 {encoding} 失败")
continue
except json.JSONDecodeError as e:
print(f"JSON格式错误: {e}")
raise
except Exception as e:
print(f"使用编码 {encoding} 时出现其他错误: {e}")
continue
raise Exception(f"无法读取文件 {file_path}")
def clean_text(self, text: str) -> str:
if not isinstance(text, str):
return str(text)
text = re.sub(r'\r\n|\r|\n', ' ', text)
text = re.sub(r'\s+', ' ', text)
text = re.sub(r'<[^>]+>', '', text)
text = re.sub(r'[^\u4e00-\u9fa5a-zA-Z0-9、,%.。!~?;:""''()【】\-\s]', '', text)
punctuation_map = {
',,': '',
'..': '',
',。': '',
',。': '',
'!!': '',
'??': '',
';;': ''
}
for old, new in punctuation_map.items():
text = text.replace(old, new)
return text.strip()
def remove_paragraph_duplicates(self, text: str, similarity_threshold: float = 0.85) -> str:
paragraphs = text.split('。')
paragraphs = [p.strip() for p in paragraphs if p.strip() and len(p) > 0]
unique_paragraphs = []
for paragraph in paragraphs:
is_similar = False
for existing in unique_paragraphs:
similarity = SequenceMatcher(None, paragraph, existing).ratio()
if similarity > similarity_threshold:
is_similar = True
if len(paragraph) > len(existing):
unique_paragraphs[unique_paragraphs.index(existing)] = paragraph
break
if not is_similar:
unique_paragraphs.append(paragraph)
return '。'.join(unique_paragraphs)
def remove_sentence_duplicates(self, text: str, similarity_threshold: float = 0.9) -> str:
sentences = re.split(r'[。!?;]', text)
sentences = [s.strip() for s in sentences if s.strip() and len(s) > 0]
unique_sentences = []
for sentence in sentences:
is_duplicate = False
for existing in unique_sentences:
similarity = SequenceMatcher(None, sentence, existing).ratio()
if similarity > similarity_threshold:
is_duplicate = True
if len(sentence) > len(existing):
unique_sentences[unique_sentences.index(existing)] = sentence
break
if not is_duplicate:
unique_sentences.append(sentence)
result = []
for sentence in unique_sentences:
if sentence:
if any(word in sentence for word in ['', '提醒', '注意', '防止']):
result.append(sentence + '')
elif '' in sentence or sentence.endswith('') or sentence.endswith(''):
result.append(sentence + '')
elif any(word in sentence for word in ['', '重要', '紧急', '警告']):
result.append(sentence + '')
else:
result.append(sentence + '')
return ''.join(result)
def remove_phrase_duplicates(self, text: str, min_phrase_length: int = 4, max_phrase_length: int = 20) -> str:
words = re.findall(r'[\u4e00-\u9fa5a-zA-Z0-9]+', text)
phrases = []
for n in range(min_phrase_length, min(max_phrase_length + 1, len(words) + 1)):
for i in range(len(words) - n + 1):
phrase = ''.join(words[i:i + n])
if len(phrase) >= min_phrase_length:
phrases.append(phrase)
phrase_counts = Counter(phrases)
frequent_phrases = [(phrase, count) for phrase, count in phrase_counts.items()
if count >= 3 and len(phrase) >= 6]
cleaned_text = text
for phrase, count in sorted(frequent_phrases, key=lambda x: len(x[0]), reverse=True):
if phrase in cleaned_text:
first_occurrence = cleaned_text.find(phrase)
remaining_text = cleaned_text[first_occurrence + len(phrase):]
removed_count = remaining_text.count(phrase)
if removed_count > 0:
cleaned_text = cleaned_text[:first_occurrence + len(phrase)] + remaining_text.replace(phrase, '')
return cleaned_text
def comprehensive_deduplication(self, text: str) -> str:
# 1. 文本清理
text = self.clean_text(text)
# 2. 段落级别去重
text = self.remove_paragraph_duplicates(text, 0.85)
# 3. 句子级别去重
text = self.remove_sentence_duplicates(text, 0.9)
# 4. 短语级别去重
text = self.remove_phrase_duplicates(text, 4, 15)
# 5. 最终标点符号规范化
punctuation_map = {
',,': '',
'..': '',
',。': '',
',。': '',
'!!': '',
'??': '',
';;': ''
}
for old, new in punctuation_map.items():
text = text.replace(old, new)
return text
def process_single_broadcast(self, broadcast_data: Dict[str, Any]) -> Dict[str, Any]:
broadcast_id = broadcast_data.get('广播ID', 'unknown')
content = broadcast_data.get('广播内容', '')
print(f"处理广播ID: {broadcast_id}")
if not content:
return {
'broadcast_id': broadcast_id,
'original_content': content,
'deduplicated_content': content,
'processing_status': 'empty_content'
}
try:
deduplicated_content = self.comprehensive_deduplication(content)
return {
'broadcast_id': broadcast_id,
'original_content': content,
'deduplicated_content': deduplicated_content,
'processing_status': 'success'
}
except Exception as e:
print(f"处理广播ID {broadcast_id} 时出错: {str(e)}")
return {
'broadcast_id': broadcast_id,
'original_content': content,
'deduplicated_content': content,
'processing_status': 'error'
}
def process_broadcast_data(self, input_file: str = 'test.json', output_file: str = 'deduplication_results.json'):
try:
# 读取输入文件
print(f"读取输入文件: {input_file}")
data = self.safe_read_json(input_file)
results = []
# 判断数据类型并处理
if isinstance(data, dict):
# 单条广播
print("检测到单条广播数据")
result = self.process_single_broadcast(data)
results.append(result)
elif isinstance(data, list):
# 广播数组
print(f"检测到广播数组,共 {len(data)} 条广播")
for i, broadcast in enumerate(data, 1):
print(f"处理第 {i}/{len(data)} 条广播")
result = self.process_single_broadcast(broadcast)
results.append(result)
else:
raise ValueError("不支持的数据格式,请提供单条广播对象或广播数组")
simplified_results = []
successful_count = 0
for result in results:
if result['processing_status'] == 'success':
simplified_item = {
'broadcast_id': result['broadcast_id'],
'original_content': result['original_content'],
'deduplicated_content': result['deduplicated_content']
}
simplified_results.append(simplified_item)
successful_count += 1
# 输出处理统计
print(f"\n处理完成!")
print(f"总计处理: {len(results)} 条广播")
print(f"成功处理: {successful_count}")
print(f"处理失败: {len(results) - successful_count}")
# 保存简化结果
print(f"\n保存简化结果到: {output_file}")
with open(output_file, 'w', encoding='utf-8') as f:
json.dump(simplified_results, f, ensure_ascii=False, indent=2)
print("处理完成!")
return simplified_results
except Exception as e:
print(f"处理过程中出现错误: {str(e)}")
raise
def main():
deduplicator = BroadcastDeduplicator()
# 检查输入文件是否存在
input_file = 'test.json'
if not os.path.exists(input_file):
print(f"输入文件 {input_file} 不存在!")
print("请创建包含广播数据的 test.json 文件")
print("\n支持的格式示例:")
print("1. 单条广播:")
print('{"广播内容": "今天天气很好。今天天气很好。", "广播ID": "broadcast_001"}')
print("\n2. 广播数组:")
print('[{"广播内容": "第一条...", "广播ID": "001"}, {"广播内容": "第二条...", "广播ID": "002"}]')
return
try:
results = deduplicator.process_broadcast_data(input_file, 'deduplication_results.json')
print(f"\n简化结果已保存到 deduplication_results.json")
print(f"成功处理了 {len(results)} 条广播")
except Exception as e:
print(f"程序执行失败: {str(e)}")
if __name__ == "__main__":
main()
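A usage sketch, reusing the single-broadcast record format shown in main() above (the ID and text are placeholders):
dedup = BroadcastDeduplicator()
record = {"广播ID": "broadcast_001", "广播内容": "今天天气很好。今天天气很好。"}
result = dedup.process_single_broadcast(record)
print(result["processing_status"])       # "success" unless an exception is raised
print(result["deduplicated_content"])    # the repeated sentence is collapsed to a single occurrence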

6782
3.merged-结构化json/all_sentence_pairs_for_annotation.json

File diff suppressed because it is too large

542
3.merged-结构化json/batch_deduplication_report_619-1103_01.txt

@@ -0,0 +1,542 @@
=== 批量多级别去重详细报告 ===
处理日期: 2025-08-11 17:04:28
总共处理: 50 个ID
成功处理: 50 个ID
总体统计:
- 平均压缩比: 24.59%
- 总原始字符数: 108,025
- 总最终字符数: 57,951
- 总减少字符数: 50,038
--- ID 1104 详细报告 ---
原始文本长度: 791 字符
最终文本长度: 791 字符
总体压缩比: 0.00%
各级别处理效果:
1. 段落级去重: 减少 1 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
--- ID 1105 详细报告 ---
原始文本长度: 791 字符
最终文本长度: 791 字符
总体压缩比: 0.00%
各级别处理效果:
1. 段落级去重: 减少 1 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
--- ID 1106 详细报告 ---
原始文本长度: 7591 字符
最终文本长度: 801 字符
总体压缩比: 89.45%
各级别处理效果:
1. 段落级去重: 减少 6791 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
PARAGRAPHS级别移除了 161 项内容
--- ID 1107 详细报告 ---
原始文本长度: 19 字符
最终文本长度: 19 字符
总体压缩比: 0.00%
各级别处理效果:
1. 段落级去重: 减少 1 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
--- ID 1108 详细报告 ---
原始文本长度: 3738 字符
最终文本长度: 1248 字符
总体压缩比: 66.61%
各级别处理效果:
1. 段落级去重: 减少 2491 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
PARAGRAPHS级别移除了 60 项内容
--- ID 1109 详细报告 ---
原始文本长度: 4841 字符
最终文本长度: 4841 字符
总体压缩比: 0.00%
各级别处理效果:
1. 段落级去重: 减少 1 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
--- ID 1110 详细报告 ---
原始文本长度: 177 字符
最终文本长度: 104 字符
总体压缩比: 41.24%
各级别处理效果:
1. 段落级去重: 减少 74 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
PARAGRAPHS级别移除了 1 项内容
--- ID 1111 详细报告 ---
原始文本长度: 212 字符
最终文本长度: 212 字符
总体压缩比: 0.00%
各级别处理效果:
1. 段落级去重: 减少 1 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
--- ID 1112 详细报告 ---
原始文本长度: 190 字符
最终文本长度: 116 字符
总体压缩比: 38.95%
各级别处理效果:
1. 段落级去重: 减少 75 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
PARAGRAPHS级别移除了 1 项内容
--- ID 1113 详细报告 ---
原始文本长度: 1282 字符
最终文本长度: 1282 字符
总体压缩比: 0.00%
各级别处理效果:
1. 段落级去重: 减少 1 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
--- ID 1114 详细报告 ---
原始文本长度: 5262 字符
最终文本长度: 5262 字符
总体压缩比: 0.00%
各级别处理效果:
1. 段落级去重: 减少 1 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
--- ID 1115 详细报告 ---
原始文本长度: 5328 字符
最终文本长度: 2005 字符
总体压缩比: 62.37%
各级别处理效果:
1. 段落级去重: 减少 2707 字符
2. 句子级去重: 减少 616 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
PARAGRAPHS级别移除了 33 项内容
SENTENCES级别移除了 7 项内容
--- ID 1116 详细报告 ---
原始文本长度: 5127 字符
最终文本长度: 5117 字符
总体压缩比: 0.20%
各级别处理效果:
1. 段落级去重: 减少 11 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
PARAGRAPHS级别移除了 1 项内容
--- ID 1117 详细报告 ---
原始文本长度: 400 字符
最终文本长度: 400 字符
总体压缩比: 0.00%
各级别处理效果:
1. 段落级去重: 减少 1 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
--- ID 1118 详细报告 ---
原始文本长度: 1296 字符
最终文本长度: 817 字符
总体压缩比: 36.96%
各级别处理效果:
1. 段落级去重: 减少 480 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
PARAGRAPHS级别移除了 11 项内容
--- ID 1119 详细报告 ---
原始文本长度: 445 字符
最终文本长度: 284 字符
总体压缩比: 36.18%
各级别处理效果:
1. 段落级去重: 减少 162 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
PARAGRAPHS级别移除了 2 项内容
--- ID 1120 详细报告 ---
原始文本长度: 795 字符
最终文本长度: 422 字符
总体压缩比: 46.92%
各级别处理效果:
1. 段落级去重: 减少 374 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
PARAGRAPHS级别移除了 10 项内容
--- ID 1121 详细报告 ---
原始文本长度: 796 字符
最终文本长度: 424 字符
总体压缩比: 46.73%
各级别处理效果:
1. 段落级去重: 减少 373 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
PARAGRAPHS级别移除了 10 项内容
--- ID 1122 详细报告 ---
原始文本长度: 125 字符
最终文本长度: 125 字符
总体压缩比: 0.00%
各级别处理效果:
1. 段落级去重: 减少 1 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
--- ID 1123 详细报告 ---
原始文本长度: 37 字符
最终文本长度: 37 字符
总体压缩比: 0.00%
各级别处理效果:
1. 段落级去重: 减少 1 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
--- ID 1124 详细报告 ---
原始文本长度: 3675 字符
最终文本长度: 3175 字符
总体压缩比: 13.61%
各级别处理效果:
1. 段落级去重: 减少 501 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
PARAGRAPHS级别移除了 14 项内容
--- ID 1125 详细报告 ---
原始文本长度: 498 字符
最终文本长度: 249 字符
总体压缩比: 50.00%
各级别处理效果:
1. 段落级去重: 减少 250 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
PARAGRAPHS级别移除了 1 项内容
--- ID 1126 详细报告 ---
原始文本长度: 2461 字符
最终文本长度: 486 字符
总体压缩比: 80.25%
各级别处理效果:
1. 段落级去重: 减少 1976 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
PARAGRAPHS级别移除了 40 项内容
--- ID 1127 详细报告 ---
原始文本长度: 2442 字符
最终文本长度: 1120 字符
总体压缩比: 54.14%
各级别处理效果:
1. 段落级去重: 减少 1323 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
PARAGRAPHS级别移除了 15 项内容
--- ID 1128 详细报告 ---
原始文本长度: 2560 字符
最终文本长度: 1779 字符
总体压缩比: 30.51%
各级别处理效果:
1. 段落级去重: 减少 782 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
PARAGRAPHS级别移除了 8 项内容
--- ID 1129 详细报告 ---
原始文本长度: 2561 字符
最终文本长度: 1788 字符
总体压缩比: 30.18%
各级别处理效果:
1. 段落级去重: 减少 774 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
PARAGRAPHS级别移除了 7 项内容
--- ID 1130 详细报告 ---
原始文本长度: 673 字符
最终文本长度: 673 字符
总体压缩比: 0.00%
各级别处理效果:
1. 段落级去重: 减少 1 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
--- ID 1131 详细报告 ---
原始文本长度: 264 字符
最终文本长度: 264 字符
总体压缩比: 0.00%
各级别处理效果:
1. 段落级去重: 减少 1 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
--- ID 1132 详细报告 ---
原始文本长度: 1566 字符
最终文本长度: 1442 字符
总体压缩比: 7.92%
各级别处理效果:
1. 段落级去重: 减少 125 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
PARAGRAPHS级别移除了 2 项内容
--- ID 1133 详细报告 ---
原始文本长度: 1559 字符
最终文本长度: 1559 字符
总体压缩比: 0.00%
各级别处理效果:
1. 段落级去重: 减少 1 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
--- ID 1134 详细报告 ---
原始文本长度: 2510 字符
最终文本长度: 356 字符
总体压缩比: 85.82%
各级别处理效果:
1. 段落级去重: 减少 2155 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
PARAGRAPHS级别移除了 141 项内容
--- ID 1135 详细报告 ---
原始文本长度: 2530 字符
最终文本长度: 380 字符
总体压缩比: 84.98%
各级别处理效果:
1. 段落级去重: 减少 2151 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
PARAGRAPHS级别移除了 158 项内容
--- ID 1136 详细报告 ---
原始文本长度: 251 字符
最终文本长度: 251 字符
总体压缩比: 0.00%
各级别处理效果:
1. 段落级去重: 减少 1 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
--- ID 1137 详细报告 ---
原始文本长度: 3153 字符
最终文本长度: 571 字符
总体压缩比: 81.89%
各级别处理效果:
1. 段落级去重: 减少 2583 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
PARAGRAPHS级别移除了 147 项内容
--- ID 1138 详细报告 ---
原始文本长度: 917 字符
最终文本长度: 883 字符
总体压缩比: 3.71%
各级别处理效果:
1. 段落级去重: 减少 35 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
PARAGRAPHS级别移除了 1 项内容
--- ID 1139 详细报告 ---
原始文本长度: 908 字符
最终文本长度: 857 字符
总体压缩比: 5.62%
各级别处理效果:
1. 段落级去重: 减少 52 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
PARAGRAPHS级别移除了 1 项内容
--- ID 1140 详细报告 ---
原始文本长度: 2797 字符
最终文本长度: 1656 字符
总体压缩比: 40.79%
各级别处理效果:
1. 段落级去重: 减少 1142 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
PARAGRAPHS级别移除了 25 项内容
--- ID 1141 详细报告 ---
原始文本长度: 800 字符
最终文本长度: 800 字符
总体压缩比: 0.00%
各级别处理效果:
1. 段落级去重: 减少 1 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
--- ID 1142 详细报告 ---
原始文本长度: 618 字符
最终文本长度: 598 字符
总体压缩比: 3.24%
各级别处理效果:
1. 段落级去重: 减少 21 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
PARAGRAPHS级别移除了 1 项内容
--- ID 1143 详细报告 ---
原始文本长度: 1330 字符
最终文本长度: 732 字符
总体压缩比: 44.96%
各级别处理效果:
1. 段落级去重: 减少 599 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
PARAGRAPHS级别移除了 11 项内容
--- ID 1144 详细报告 ---
原始文本长度: 22010 字符
最终文本长度: 1494 字符
总体压缩比: 93.21%
各级别处理效果:
1. 段落级去重: 减少 20517 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
PARAGRAPHS级别移除了 342 项内容
--- ID 1145 详细报告 ---
原始文本长度: 42 字符
最终文本长度: 42 字符
总体压缩比: 0.00%
各级别处理效果:
1. 段落级去重: 减少 1 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
--- ID 1146 详细报告 ---
原始文本长度: 771 字符
最终文本长度: 771 字符
总体压缩比: 0.00%
各级别处理效果:
1. 段落级去重: 减少 1 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
--- ID 1147 详细报告 ---
原始文本长度: 1183 字符
最终文本长度: 1183 字符
总体压缩比: 0.00%
各级别处理效果:
1. 段落级去重: 减少 1 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
--- ID 1148 详细报告 ---
原始文本长度: 1184 字符
最终文本长度: 1184 字符
总体压缩比: 0.00%
各级别处理效果:
1. 段落级去重: 减少 1 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
--- ID 1149 详细报告 ---
原始文本长度: 3964 字符
最终文本长度: 3964 字符
总体压缩比: 0.00%
各级别处理效果:
1. 段落级去重: 减少 1 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
--- ID 1150 详细报告 ---
原始文本长度: 1263 字符
最终文本长度: 1191 字符
总体压缩比: 5.70%
各级别处理效果:
1. 段落级去重: 减少 73 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
PARAGRAPHS级别移除了 3 项内容
--- ID 1151 详细报告 ---
原始文本长度: 1611 字符
最终文本长度: 1524 字符
总体压缩比: 5.40%
各级别处理效果:
1. 段落级去重: 减少 88 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
PARAGRAPHS级别移除了 8 项内容
--- ID 1152 详细报告 ---
原始文本长度: 1810 字符
最终文本长度: 1046 字符
总体压缩比: 42.21%
各级别处理效果:
1. 段落级去重: 减少 765 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
PARAGRAPHS级别移除了 16 项内容
--- ID 1153 详细报告 ---
原始文本长度: 835 字符
最终文本长度: 835 字符
总体压缩比: 0.00%
各级别处理效果:
1. 段落级去重: 减少 1 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符

51
3.merged-结构化json/batch_deduplication_results_619-1103_01.csv

File diff suppressed because one or more lines are too long

155
3.merged-结构化json/json_数据标注.py

@@ -0,0 +1,155 @@
import re
import json
import pandas as pd
import os
def split_sentences(text):
# 使用捕获组来保留分隔符
parts = re.split(r'([。!?])', text)
# 重新组合句子和标点符号
sentences = []
for i in range(0, len(parts), 2):
if i < len(parts) and parts[i].strip():
# 如果有对应的标点符号,就加上
punctuation = parts[i + 1] if i + 1 < len(parts) else ''
sentence = parts[i].strip() + punctuation
sentences.append(sentence)
return sentences
def create_sentence_pairs(sentences):
pairs = []
for i in range(len(sentences) - 1):
pair = {
"sentence1": sentences[i],
"sentence2": sentences[i + 1],
"label": -1 # 待标注
}
pairs.append(pair)
return pairs
# 从CSV文件中读取所有内容
try:
# 尝试不同的编码格式读取CSV文件
encodings = ['utf-8', 'gbk', 'gb2312', 'utf-8-sig', 'latin1']
df = None
for encoding in encodings:
try:
print(f"尝试使用 {encoding} 编码读取文件...")
df = pd.read_csv('batch_deduplication_results_619-1103_01.csv', encoding=encoding)
print(f"成功使用 {encoding} 编码读取文件")
break
except UnicodeDecodeError:
continue
if df is None:
print("错误:尝试了所有常见编码都无法读取文件")
exit()
except FileNotFoundError:
print("错误:找不到文件 'batch_deduplication_results_619-1103_01.csv'")
exit()
except Exception as e:
print(f"读取CSV文件时发生错误:{e}")
exit()
# 创建输出目录
output_dir = 'sentence_pairs_output_all'
if not os.path.exists(output_dir):
os.makedirs(output_dir)
# 汇总所有数据
all_sentence_pairs = []
summary_info = []
print(f"CSV文件共有 {len(df)} 行数据")
print("开始遍历所有ID...")
# 遍历所有行
for index, row in df.iterrows():
try:
current_id = row['id']
raw_text = row['final_processed_text']
# 检查文本是否为空
if pd.isna(raw_text) or str(raw_text).strip() == '':
print(f"ID {current_id}: 文本内容为空,跳过")
summary_info.append({
'id': current_id,
'status': '文本为空',
'sentences_count': 0,
'pairs_count': 0
})
continue
# 执行分割和配对
sentences = split_sentences(str(raw_text))
sentence_pairs = create_sentence_pairs(sentences)
# 为每个句子对添加来源ID
for pair in sentence_pairs:
pair['source_id'] = current_id
# 添加到汇总数据
all_sentence_pairs.extend(sentence_pairs)
# 为每个ID单独保存文件
if sentence_pairs: # 只有当有句子对时才保存
filename = f'sentence_pairs_id_{current_id}.json'
filepath = os.path.join(output_dir, filename)
with open(filepath, 'w', encoding='utf-8') as f:
json.dump(sentence_pairs, f, ensure_ascii=False, indent=2)
# 记录处理信息
summary_info.append({
'id': current_id,
'status': '成功处理',
'sentences_count': len(sentences),
'pairs_count': len(sentence_pairs),
'text_length': len(str(raw_text))
})
print(f"ID {current_id}: 分割出 {len(sentences)} 个句子,生成 {len(sentence_pairs)} 个句子对")
except Exception as e:
print(f"处理ID {current_id} 时发生错误:{e}")
summary_info.append({
'id': current_id,
'status': f'错误: {str(e)}',
'sentences_count': 0,
'pairs_count': 0
})
# 保存汇总的所有句子对数据
print("\n保存汇总数据...")
with open('all_sentence_pairs_for_annotation.json', 'w', encoding='utf-8') as f:
json.dump(all_sentence_pairs, f, ensure_ascii=False, indent=2)
# 保存处理摘要
summary_df = pd.DataFrame(summary_info)
summary_df.to_csv('processing_summary.csv', index=False, encoding='utf-8-sig')
# 统计信息
total_sentences = sum([info['sentences_count'] for info in summary_info])
total_pairs = sum([info['pairs_count'] for info in summary_info])
successful_ids = len([info for info in summary_info if info['status'] == '成功处理'])
print(f"\n=== 处理完成 ===")
print(f"总计处理了 {len(df)} 个ID")
print(f"成功处理 {successful_ids} 个ID")
print(f"总计分割出 {total_sentences} 个句子")
print(f"总计生成 {total_pairs} 个句子对")
print(f"汇总数据保存到: all_sentence_pairs_for_annotation.json")
print(f"单独文件保存在: {output_dir}/ 目录")
print(f"处理摘要保存到: processing_summary.csv")
# 显示前几个句子对的示例
if all_sentence_pairs:
print("\n前3个句子对示例:")
for i in range(min(3, len(all_sentence_pairs))):
print(f"\n{i + 1}对 (来源ID: {all_sentence_pairs[i]['source_id']}):")
print(f"句子1: {all_sentence_pairs[i]['sentence1']}")
print(f"句子2: {all_sentence_pairs[i]['sentence2']}")
print(f"标签: {all_sentence_pairs[i]['label']}")

514
4.结构化json-Ai标注/AIbiaozhu.py

@@ -0,0 +1,514 @@
import requests
import json
import pandas as pd
from typing import List, Dict
import time
class SimpleOpenAIHubClient:
def __init__(self, api_key):
self.api_key = api_key
self.base_url = "https://api.openai-hub.com"
self.headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json"
}
def chat(self, prompt, model="gpt-4.1"):
"""发送prompt并返回模型回答"""
payload = {
"model": model,
"messages": [
{
"role": "user",
"content": prompt
}
],
"max_tokens": 32768,
"temperature": 0.7
}
try:
response = requests.post(
f"{self.base_url}/v1/chat/completions",
headers=self.headers,
json=payload,
timeout=60
)
if response.status_code == 200:
result = response.json()
return result['choices'][0]['message']['content']
else:
return f"错误: {response.status_code} - {response.text}"
except requests.exceptions.RequestException as e:
return f"请求异常: {str(e)}"
print("AI客户端类定义完成!")
# 设置API Key
API_KEY = "sk-XREp2jnIXyZ6UoCnzZeO0ahmLi9OEXuVAtFLojKFpG9gCZ4e" # 请替换为你的实际API Key
# 初始化AI客户端
client = SimpleOpenAIHubClient(API_KEY)
print("AI模型加载完成!")
# 定义批量标注的Prompt模板
BATCH_SEGMENTATION_PROMPT = """你是一个专业的广播内容段落分割标注员。你的任务是批量判断多个相邻句子对之间是否应该进行段落分割,以便广播员更好地掌握停顿和语调变化。
**完整文本内容上下文**
{context_text}
**标注规则**
- 标签0:两个句子属于同一段落,连续播报,轻微停顿
- 标签1:两个句子属于不同段落,需要明显停顿或语调转换
**重要标注要求(请严格遵循)**
- 如果整个文本内容都在讲同一个事,你有理由只输出一段,不是追求分的段越多越细就越好
- 每个分段必须保持原始语句的绝对顺序
- 最终分段数可能等于或小于原始语句数量
- 必须保留所有原始语句文本,不得遗漏任何内容
- 应客户强烈要求,他们需要的是较粗的分段,不要太细,如同一条通告不需要分段成具体的每个条款之类的,只需要将整个相同的通告分成一段
- 优先考虑较粗的分段,避免过度细分
**广播分段判断依据(偏向粗分段)**
1. **重大主题转换**:从一个完全不同的话题转向另一个话题,如从天气预报转向安全通知
2. **文档类型变化**:从一个完整文档转向另一个完整文档,如从禁火令转向倡议书
3. **内容性质变化**:从通知类内容转向完全不同性质的内容,如从法规转向天气预报
4. **广播节目段落**:明显的广播节目结构变化,如开场白结束进入正式内容
5. **分点阐述结构**:标题和所有分点条目内容应该合并为一个完整段落,"森林防火十不准,一不乱扔烟头,二不随意丢弃火种,三不在林区吸烟"等整体合成一段
**广播内容特别注意(粗分段原则)**
- 整个通告、法令、倡议书等应作为一个段落,不要拆分条款
- 同一主题的多个条款应保持在同一段落
- 只有在完全不同的文档或重大主题转换时才分段
- 广播开场白可以独立成段但具体内容尽量合并
- 同一类型的预报信息如天气预报的不同地区应保持在同一段
- **分点阐述内容的特殊处理**
- 标题性内容"森林防火十不准"与分点条目内容之间不需要分段
- 标题和所有的分点条目"一不乱扔烟头""二不随意丢弃火种""三不在林区吸烟"应该合并为一个完整段落
- 分点条目之间不需要分段应该连续播报
- 整个分点阐述结构作为一个完整的内容单元保持连贯性
**批量标注说明**
- 每个句子对都有一个source_id表示来源文档
- 请保持原有的source_id不变
- 将label从-1改为实际的标注结果0或1
- 为每个句子对提供简要的分段理由
- 结合上述完整文本内容理解句子对的上下文语境
- **特别重要:倾向于标注更多的0(同段落),减少1(分段)的使用;分点阐述结构应保持为一个完整段落**
现在请对以下句子对进行批量标注:
{batch_sentence_pairs}
请直接输出标注结果,格式如下:
```json
[
{{
"sentence1": "...",
"sentence2": "...",
"label": 0或1,
"reason": "广播分段理由",
"source_id": 原有的source_id
}}
]
```
只输出JSON数据,不要其他说明文字"""
def load_context_data(csv_file="batch_deduplication_results_619-1103_01.csv"):
"""
从batch_deduplication_results_619-1103_01.csv加载上下文数据
Args:
csv_file: CSV文件路径
Returns:
字典key为idvalue为final_processed_text
"""
try:
print(f"正在读取上下文数据文件: {csv_file}")
# 尝试不同的编码格式
encodings = ['utf-8', 'gbk', 'gb2312', 'utf-8-sig', 'latin-1', 'cp1252']
context_df = None
for encoding in encodings:
try:
print(f" 尝试使用 {encoding} 编码...")
context_df = pd.read_csv(csv_file, encoding=encoding)
print(f" ✓ 成功使用 {encoding} 编码读取文件")
break
except UnicodeDecodeError:
print(f" {encoding} 编码失败")
continue
except Exception as e:
print(f" {encoding} 编码读取时出现其他错误:{str(e)}")
continue
if context_df is None:
print(f"✗ 错误:尝试了所有编码格式都无法读取文件 {csv_file}")
return {}
if 'id' not in context_df.columns or 'final_processed_text' not in context_df.columns:
print(f"✗ 错误:CSV文件缺少必需列")
print(f" 需要的列: ['id', 'final_processed_text']")
print(f" 实际的列: {list(context_df.columns)}")
return {}
# 创建id到final_processed_text的映射
context_dict = {}
for _, row in context_df.iterrows():
context_dict[row['id']] = row['final_processed_text'] if pd.notna(row['final_processed_text']) else ""
print(f"✓ 成功加载上下文数据")
print(f" - 可用ID数量: {len(context_dict)}")
print(f" - 可用ID列表: {sorted(context_dict.keys())}")
return context_dict
except FileNotFoundError:
print(f"✗ 警告:找不到上下文文件 {csv_file}")
print(" 将在没有上下文的情况下进行标注")
return {}
except Exception as e:
print(f"✗ 读取上下文文件时出错: {str(e)}")
print(" 将在没有上下文的情况下进行标注")
return {}
def process_batch_segmentation(sentence_pairs_data, context_dict, batch_size=10):
"""
批量处理句子对的段落分割标注
Args:
sentence_pairs_data: 句子对数据列表
context_dict: 上下文数据字典
batch_size: 每批处理的数量
Returns:
处理结果列表
"""
all_results = []
total_pairs = len(sentence_pairs_data)
print(f"开始批量标注,总共 {total_pairs} 个句子对")
print(f"每批处理 {batch_size} 个句子对")
# 分批处理
for i in range(0, total_pairs, batch_size):
batch_end = min(i + batch_size, total_pairs)
current_batch = sentence_pairs_data[i:batch_end]
print(f"\n处理第 {i // batch_size + 1} 批 (句子对 {i + 1}-{batch_end})")
try:
# 获取当前批次涉及的source_id的上下文
source_ids_in_batch = set(pair['source_id'] for pair in current_batch)
context_text = ""
for source_id in sorted(source_ids_in_batch):
if source_id in context_dict and context_dict[source_id]:
context_text += f"\n--- Source ID {source_id} 完整文本内容 ---\n"
context_text += context_dict[source_id] # 完整内容,不截断
context_text += "\n"
else:
context_text += f"\n--- Source ID {source_id} ---\n(未找到对应的完整文本内容)\n"
# 准备当前批次的数据
batch_json = json.dumps(current_batch, ensure_ascii=False, indent=2)
# 构建prompt
prompt = BATCH_SEGMENTATION_PROMPT.format(
context_text=context_text,
batch_sentence_pairs=batch_json
)
print(f"发送请求到AI模型...")
print(f" - 涉及source_id: {sorted(source_ids_in_batch)}")
print(f" - 上下文长度: {len(context_text)} 字符")
# 调用AI模型
ai_response = client.chat(prompt)
print(f"收到模型响应")
# 尝试解析JSON响应
try:
# 提取JSON部分(去除可能的markdown格式)
json_start = ai_response.find('[')
json_end = ai_response.rfind(']') + 1
if json_start != -1 and json_end != 0:
json_content = ai_response[json_start:json_end]
batch_deduplication_results_all = json.loads(json_content)
# 验证结果
if isinstance(batch_deduplication_results_all, list) and len(batch_deduplication_results_all) == len(current_batch):
all_results.extend(batch_deduplication_results_all)
print(f"✓ 成功处理 {len(batch_deduplication_results_all)} 个句子对")
else:
print(
f"✗ 响应格式不正确,期望 {len(current_batch)} 个结果,实际得到 {len(batch_deduplication_results_all) if isinstance(batch_deduplication_results_all, list) else 'non-list'}")
# 添加错误记录
for j, pair in enumerate(current_batch):
all_results.append({
"sentence1": pair["sentence1"],
"sentence2": pair["sentence2"],
"label": -1,
"reason": "处理失败:响应格式错误",
"source_id": pair["source_id"]
})
else:
print(f"✗ 无法找到有效的JSON响应")
print(f"原始响应前200字符: {ai_response[:200]}...")
# 添加错误记录
for j, pair in enumerate(current_batch):
all_results.append({
"sentence1": pair["sentence1"],
"sentence2": pair["sentence2"],
"label": -1,
"reason": "处理失败:JSON解析错误",
"source_id": pair["source_id"]
})
except json.JSONDecodeError as e:
print(f"✗ JSON解析失败: {str(e)}")
print(f"原始响应: {ai_response[:200]}...")
# 添加错误记录
for j, pair in enumerate(current_batch):
all_results.append({
"sentence1": pair["sentence1"],
"sentence2": pair["sentence2"],
"label": -1,
"reason": f"处理失败:{str(e)}",
"source_id": pair["source_id"]
})
# 添加延时,避免API调用过于频繁
time.sleep(2)
except Exception as e:
print(f"✗ 批次处理出错: {str(e)}")
# 添加错误记录
for j, pair in enumerate(current_batch):
all_results.append({
"sentence1": pair["sentence1"],
"sentence2": pair["sentence2"],
"label": -1,
"reason": f"处理异常:{str(e)}",
"source_id": pair["source_id"]
})
return all_results
# 执行批量标注
print("=" * 60)
print("开始执行批量段落分割标注(从source_id 7开始)")
print("=" * 60)
# 加载上下文数据
context_dict = load_context_data("batch_deduplication_results_619-1103_01.csv")
# 从JSON文件加载数据
input_file = "all_sentence_pairs_for_annotation.json"
try:
print(f"正在读取数据文件: {input_file}")
with open(input_file, 'r', encoding='utf-8') as f:
all_sentence_pairs_data = json.load(f)
print(f"✓ 成功加载数据文件")
print(f" - 总句子对数量: {len(all_sentence_pairs_data)}")
# 检查数据格式
if len(all_sentence_pairs_data) > 0:
sample_item = all_sentence_pairs_data[0]
required_fields = ['sentence1', 'sentence2', 'source_id']
missing_fields = [field for field in required_fields if field not in sample_item]
if missing_fields:
print(f"✗ 数据格式错误,缺少字段: {missing_fields}")
print(f" 实际字段: {list(sample_item.keys())}")
exit()
# 获取所有unique的source_id
all_source_ids = sorted(set(item.get('source_id') for item in all_sentence_pairs_data))
print(f"✓ 发现的source_id列表: {all_source_ids}")
# 【修改】:筛选出source_id >= 7的ID
filtered_source_ids = [sid for sid in all_source_ids if sid >= 7]
print(f"✓ 筛选后的source_id列表(>=7): {filtered_source_ids}")
if not filtered_source_ids:
print("✗ 没有找到source_id >= 7的数据")
exit()
# 统计各source_id的句子对数量
from collections import Counter
source_counts = Counter(item.get('source_id') for item in all_sentence_pairs_data)
print(f" - 各source_id的句子对数量(>=7):")
for source_id in filtered_source_ids:
print(f" source_id {source_id}: {source_counts[source_id]}")
# 检查上下文数据可用性
print(f"\n上下文数据可用性检查:")
available_context_ids = []
missing_context_ids = []
for source_id in filtered_source_ids:
if source_id in context_dict and context_dict[source_id]:
available_context_ids.append(source_id)
print(f" ✓ source_id {source_id}: 上下文长度 {len(context_dict[source_id])} 字符")
else:
missing_context_ids.append(source_id)
print(f" ✗ source_id {source_id}: 缺少上下文数据")
if missing_context_ids:
print(f"\n警告:以下source_id缺少上下文数据: {missing_context_ids}")
print("这些ID的标注可能不够准确")
print(f"\n开始处理source_id >= 7的标注任务...")
except FileNotFoundError:
print(f"✗ 错误:找不到文件 {input_file}")
print("请确保文件存在于当前目录中")
exit()
except json.JSONDecodeError as e:
print(f"✗ 错误:JSON文件格式错误 - {str(e)}")
exit()
except Exception as e:
print(f"✗ 错误:读取文件时出现异常 - {str(e)}")
exit()
# 准备所有结果
all_results = []
# 【修改】:遍历筛选后的source_id(>=7)
for current_source_id in filtered_source_ids:
print(f"\n{'=' * 50}")
print(f"处理 source_id: {current_source_id}")
print(f"{'=' * 50}")
# 筛选当前source_id的数据
current_sentence_pairs = [item for item in all_sentence_pairs_data if item.get('source_id') == current_source_id]
if len(current_sentence_pairs) == 0:
print(f"✗ 警告:source_id={current_source_id} 没有找到句子对数据")
continue
print(f" - 当前source_id的句子对数量: {len(current_sentence_pairs)}")
# 显示第一个句子对预览
if len(current_sentence_pairs) > 0:
first_pair = current_sentence_pairs[0]
print(f" - 第一个句子对预览:")
print(f" 句子1: {first_pair['sentence1'][:60]}...")
print(f" 句子2: {first_pair['sentence2'][:60]}...")
# 检查上下文数据
if current_source_id in context_dict and context_dict[current_source_id]:
print(f" - 上下文数据: 可用,长度 {len(context_dict[current_source_id])} 字符")
print(f" - 上下文预览: {context_dict[current_source_id][:100]}...")
else:
print(f" - 上下文数据: 不可用")
try:
print(f" - 开始处理标注...")
# 执行批量标注(每批处理8个句子对)
current_results = process_batch_segmentation(current_sentence_pairs, context_dict, batch_size=8)
# 添加到总结果中
all_results.extend(current_results)
# 统计当前source_id的结果
current_successful = len([r for r in current_results if r['label'] in [0, 1]])
current_failed = len([r for r in current_results if r['label'] == -1])
print(f" ✓ source_id {current_source_id} 处理完成")
print(f" - 成功标注: {current_successful}")
print(f" - 失败数量: {current_failed}")
if current_successful > 0:
current_label_0 = len([r for r in current_results if r['label'] == 0])
current_label_1 = len([r for r in current_results if r['label'] == 1])
print(f" - 标签0(同段落): {current_label_0}")
print(f" - 标签1(分段): {current_label_1}")
# 添加延时,避免处理过快
time.sleep(1)
except Exception as e:
print(f" ✗ source_id {current_source_id} 处理失败: {str(e)}")
# 为失败的source_id添加错误记录
for pair in current_sentence_pairs:
all_results.append({
"sentence1": pair["sentence1"],
"sentence2": pair["sentence2"],
"label": -1,
"reason": f"source_id处理异常:{str(e)}",
"source_id": pair["source_id"]
})
print(f"\n{'=' * 60}")
print("所有source_id(>=7)处理完成!")
print(f"{'=' * 60}")
# 使用所有结果进行后续处理
results = all_results
print(f"\n{'=' * 60}")
print("批量标注完成!")
print(f"{'=' * 60}")
# 统计结果
total_processed = len(results)
successful_labels = len([r for r in results if r['label'] in [0, 1]])
failed_labels = len([r for r in results if r['label'] == -1])
print(f"总处理数量: {total_processed}")
print(f"成功标注: {successful_labels}")
print(f"失败数量: {failed_labels}")
if successful_labels > 0:
label_0_count = len([r for r in results if r['label'] == 0])
label_1_count = len([r for r in results if r['label'] == 1])
print(f"标签0(同段落): {label_0_count}")
print(f"标签1(分段): {label_1_count}")
# 【修改】:保存结果到CSV,文件名增加后缀以区分
result_df = pd.DataFrame(results)
output_file = "segmentation_labeling_results_from_7.csv"
result_df.to_csv(output_file, index=False, encoding='utf-8-sig')
print(f"\n结果已保存到: {output_file}")
# 显示前几条结果
print(f"\n前3条标注结果预览:")
for i, result in enumerate(results[:3]):
print(f"\n{i + 1}. Source ID: {result['source_id']}")
print(f" 句子1: {result['sentence1'][:50]}...")
print(f" 句子2: {result['sentence2'][:50]}...")
print(f" 标签: {result['label']}")
print(f" 理由: {result['reason']}")
# 【修改】:保存详细的JSON结果,文件名增加后缀以区分
json_output_file = 'segmentation_results_from_7.json'
with open(json_output_file, 'w', encoding='utf-8') as f:
json.dump(results, f, ensure_ascii=False, indent=2)
print(f"详细JSON结果已保存到: {json_output_file}")
# 【新增】:显示处理的source_id范围统计
if results:
processed_source_ids = sorted(set(r['source_id'] for r in results))
print(f"\n实际处理的source_id范围: {min(processed_source_ids)} - {max(processed_source_ids)}")
print(f"共处理了 {len(processed_source_ids)} 个source_id")

581
4.结构化json-Ai标注/Ai-continue.py

@@ -0,0 +1,581 @@
import requests
import json
import pandas as pd
from typing import List, Dict
import time
class SimpleOpenAIHubClient:
def __init__(self, api_key):
self.api_key = api_key
self.base_url = "https://api.openai-hub.com"
self.headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json"
}
def chat(self, prompt, model="gpt-4.1"):
"""发送prompt并返回模型回答"""
payload = {
"model": model,
"messages": [
{
"role": "user",
"content": prompt
}
],
"max_tokens": 32768,
"temperature": 0.7
}
try:
response = requests.post(
f"{self.base_url}/v1/chat/completions",
headers=self.headers,
json=payload,
timeout=60
)
if response.status_code == 200:
result = response.json()
return result['choices'][0]['message']['content']
else:
return f"错误: {response.status_code} - {response.text}"
except requests.exceptions.RequestException as e:
return f"请求异常: {str(e)}"
print("AI客户端类定义完成!")
# 设置API Key
API_KEY = "sk-XREp2jnIXyZ6UoCnzZeO0ahmLi9OEXuVAtFLojKFpG9gCZ4e" # 请替换为你的实际API Key
# 初始化AI客户端
client = SimpleOpenAIHubClient(API_KEY)
print("AI模型加载完成!")
# 定义批量标注的Prompt模板
BATCH_SEGMENTATION_PROMPT = """你是一个专业的广播内容段落分割标注员。你的任务是批量判断多个相邻句子对之间是否应该进行段落分割,以便广播员更好地掌握停顿和语调变化。
**完整文本内容上下文**
{context_text}
**标注规则**
- 标签0:两个句子属于同一段落,连续播报,轻微停顿
- 标签1:两个句子属于不同段落,需要明显停顿或语调转换
**重要标注要求(请严格遵循)**
- 如果整个文本内容都在讲同一个事,你有理由只输出一段,不是追求分的段越多越细就越好
- 每个分段必须保持原始语句的绝对顺序
- 最终分段数可能等于或小于原始语句数量
- 必须保留所有原始语句文本,不得遗漏任何内容
- 应客户强烈要求,他们需要的是较粗的分段,不要太细,如同一条通告不需要分段成具体的每个条款之类的,只需要将整个相同的通告分成一段
- 优先考虑较粗的分段,避免过度细分
**广播分段判断依据(偏向粗分段)**
1. **重大主题转换**:从一个完全不同的话题转向另一个话题(如从天气预报转向安全通知)
2. **文档类型变化**:从一个完整文档转向另一个完整文档(如从禁火令转向倡议书)
3. **内容性质变化**:从通知类内容转向完全不同性质的内容(如从法规转向天气预报)
4. **广播节目段落**:明显的广播节目结构变化(如开场白结束,进入正式内容)
5. **分点阐述结构**:标题和所有分点条目内容应该合并为一个完整段落,"森林防火十不准,一不乱扔烟头,二不随意丢弃火种,三不在林区吸烟"等整体合成一段
**广播内容特别注意(粗分段原则)**
- 整个通告、法令、倡议书等应作为一个段落,不要拆分条款
- 同一主题的多个条款应保持在同一段落
- 只有在完全不同的文档或重大主题转换时才分段
- 广播开场白可以独立成段,但具体内容尽量合并
- 同一类型的预报信息(如天气预报的不同地区)应保持在同一段
- **分点阐述内容的特殊处理**
- 标题性内容"森林防火十不准"与分点条目内容之间不需要分段
- 标题和所有的分点条目"一不乱扔烟头"、"二不随意丢弃火种"、"三不在林区吸烟"应该合并为一个完整段落
- 分点条目之间不需要分段,应该连续播报
- 整个分点阐述结构作为一个完整的内容单元,保持连贯性
**批量标注说明**
- 每个句子对都有一个source_id,表示来源文档
- 请保持原有的source_id不变
- 将label从-1改为实际的标注结果(0或1)
- 为每个句子对提供简要的分段理由
- 结合上述完整文本内容,理解句子对的上下文语境
- **特别重要:倾向于标注更多的0(同段落),减少1(分段)的使用,分点阐述结构应保持为一个完整段落**
现在请对以下句子对进行批量标注:
{batch_sentence_pairs}
请直接输出标注结果,格式如下:
```json
[
{{
"sentence1": "...",
"sentence2": "...",
"label": 0或1,
"reason": "广播分段理由",
"source_id": 原有的source_id
}}
]
```
只输出JSON数据,不要其他说明文字。"""
def load_context_data(csv_file="batch_deduplication_results_619-1103_01.csv"):
"""
从batch_deduplication_results_619-1103_01.csv加载上下文数据
Args:
csv_file: CSV文件路径
Returns:
字典,key为id,value为final_processed_text
"""
try:
print(f"正在读取上下文数据文件: {csv_file}")
# 尝试不同的编码格式
encodings = ['utf-8', 'gbk', 'gb2312', 'utf-8-sig', 'latin-1', 'cp1252']
context_df = None
for encoding in encodings:
try:
print(f" 尝试使用 {encoding} 编码...")
context_df = pd.read_csv(csv_file, encoding=encoding)
print(f" ✓ 成功使用 {encoding} 编码读取文件")
break
except UnicodeDecodeError:
print(f" {encoding} 编码失败")
continue
except Exception as e:
print(f" {encoding} 编码读取时出现其他错误:{str(e)}")
continue
if context_df is None:
print(f"✗ 错误:尝试了所有编码格式都无法读取文件 {csv_file}")
return {}
if 'id' not in context_df.columns or 'final_processed_text' not in context_df.columns:
print(f"✗ 错误:CSV文件缺少必需列")
print(f" 需要的列: ['id', 'final_processed_text']")
print(f" 实际的列: {list(context_df.columns)}")
return {}
# 创建id到final_processed_text的映射
context_dict = {}
for _, row in context_df.iterrows():
context_dict[row['id']] = row['final_processed_text'] if pd.notna(row['final_processed_text']) else ""
print(f"✓ 成功加载上下文数据")
print(f" - 可用ID数量: {len(context_dict)}")
print(f" - 可用ID列表: {sorted(context_dict.keys())}")
return context_dict
except FileNotFoundError:
print(f"✗ 警告:找不到上下文文件 {csv_file}")
print(" 将在没有上下文的情况下进行标注")
return {}
except Exception as e:
print(f"✗ 读取上下文文件时出错: {str(e)}")
print(" 将在没有上下文的情况下进行标注")
return {}
def load_failed_data_from_json(json_file="segmentation_results_from_7.json"):
"""
从JSON结果文件中加载标注失败的数据
Args:
json_file: JSON结果文件路径
Returns:
失败数据列表
"""
try:
print(f"正在读取JSON结果文件: {json_file}")
with open(json_file, 'r', encoding='utf-8') as f:
all_results = json.load(f)
print(f"✓ 成功加载JSON文件")
print(f" - 总结果数量: {len(all_results)}")
# 筛选出失败的数据(label为-1)
failed_data = [item for item in all_results if item.get('label') == -1]
successful_data = [item for item in all_results if item.get('label') in [0, 1]]
print(f" - 成功标注数量: {len(successful_data)}")
print(f" - 失败标注数量: {len(failed_data)}")
if len(failed_data) == 0:
print("✓ 没有发现失败的标注数据,无需重新处理")
return [], all_results
# 统计失败原因
from collections import Counter
failure_reasons = Counter(item.get('reason', '未知错误') for item in failed_data)
print(f"\n失败原因统计:")
for reason, count in failure_reasons.most_common():
print(f" - {reason}: {count}")
# 统计涉及的source_id
failed_source_ids = sorted(set(item.get('source_id') for item in failed_data))
print(f"\n涉及的source_id: {failed_source_ids}")
return failed_data, all_results
except FileNotFoundError:
print(f"✗ 错误:找不到文件 {json_file}")
return [], []
except json.JSONDecodeError as e:
print(f"✗ 错误:JSON文件格式错误 - {str(e)}")
return [], []
except Exception as e:
print(f"✗ 错误:读取JSON文件时出现异常 - {str(e)}")
return [], []
def convert_failed_data_to_sentence_pairs(failed_data):
"""
将失败数据转换为句子对格式供重新标注使用
Args:
failed_data: 失败数据列表
Returns:
句子对格式的数据列表
"""
sentence_pairs_data = []
for item in failed_data:
sentence_pair = {
"sentence1": item.get("sentence1", ""),
"sentence2": item.get("sentence2", ""),
"source_id": item.get("source_id"),
"label": -1 # 标记为待标注
}
sentence_pairs_data.append(sentence_pair)
return sentence_pairs_data
def process_batch_segmentation(sentence_pairs_data, context_dict, batch_size=8):
"""
批量处理句子对的段落分割标注
Args:
sentence_pairs_data: 句子对数据列表
context_dict: 上下文数据字典
batch_size: 每批处理的数量
Returns:
处理结果列表
"""
all_results = []
total_pairs = len(sentence_pairs_data)
print(f"开始批量标注,总共 {total_pairs} 个句子对")
print(f"每批处理 {batch_size} 个句子对")
# 分批处理
for i in range(0, total_pairs, batch_size):
batch_end = min(i + batch_size, total_pairs)
current_batch = sentence_pairs_data[i:batch_end]
print(f"\n处理第 {i // batch_size + 1} 批 (句子对 {i + 1}-{batch_end})")
try:
# 获取当前批次涉及的source_id的上下文
source_ids_in_batch = set(pair['source_id'] for pair in current_batch)
context_text = ""
for source_id in sorted(source_ids_in_batch):
if source_id in context_dict and context_dict[source_id]:
context_text += f"\n--- Source ID {source_id} 完整文本内容 ---\n"
context_text += context_dict[source_id] # 完整内容,不截断
context_text += "\n"
else:
context_text += f"\n--- Source ID {source_id} ---\n(未找到对应的完整文本内容)\n"
# 准备当前批次的数据
batch_json = json.dumps(current_batch, ensure_ascii=False, indent=2)
# 构建prompt
prompt = BATCH_SEGMENTATION_PROMPT.format(
context_text=context_text,
batch_sentence_pairs=batch_json
)
print(f"发送请求到AI模型...")
print(f" - 涉及source_id: {sorted(source_ids_in_batch)}")
print(f" - 上下文长度: {len(context_text)} 字符")
print(f" - Prompt总长度: {len(prompt)} 字符")
# 调用AI模型
ai_response = client.chat(prompt)
print(f"收到模型响应")
# 尝试解析JSON响应
try:
# 提取JSON部分(去除可能的markdown格式)
json_start = ai_response.find('[')
json_end = ai_response.rfind(']') + 1
if json_start != -1 and json_end != 0:
json_content = ai_response[json_start:json_end]
batch_results = json.loads(json_content)
# 验证结果
if isinstance(batch_results, list) and len(batch_results) == len(current_batch):
all_results.extend(batch_results)
print(f"✓ 成功处理 {len(batch_results)} 个句子对")
else:
print(
f"✗ 响应格式不正确,期望 {len(current_batch)} 个结果,实际得到 {len(batch_results) if isinstance(batch_results, list) else 'non-list'}")
# 添加错误记录
for j, pair in enumerate(current_batch):
all_results.append({
"sentence1": pair["sentence1"],
"sentence2": pair["sentence2"],
"label": -1,
"reason": "重试失败:响应格式错误",
"source_id": pair["source_id"]
})
else:
print(f"✗ 无法找到有效的JSON响应")
print(f"原始响应前200字符: {ai_response[:200]}...")
# 添加错误记录
for j, pair in enumerate(current_batch):
all_results.append({
"sentence1": pair["sentence1"],
"sentence2": pair["sentence2"],
"label": -1,
"reason": "重试失败:JSON解析错误",
"source_id": pair["source_id"]
})
except json.JSONDecodeError as e:
print(f"✗ JSON解析失败: {str(e)}")
print(f"原始响应: {ai_response[:200]}...")
# 添加错误记录
for j, pair in enumerate(current_batch):
all_results.append({
"sentence1": pair["sentence1"],
"sentence2": pair["sentence2"],
"label": -1,
"reason": f"重试失败:{str(e)}",
"source_id": pair["source_id"]
})
# 添加延时,避免API调用过于频繁
time.sleep(2)
except Exception as e:
print(f"✗ 批次处理出错: {str(e)}")
# 添加错误记录
for j, pair in enumerate(current_batch):
all_results.append({
"sentence1": pair["sentence1"],
"sentence2": pair["sentence2"],
"label": -1,
"reason": f"重试异常:{str(e)}",
"source_id": pair["source_id"]
})
return all_results
def merge_results(original_results, retry_results):
"""
合并原始结果和重试结果
Args:
original_results: 原始完整结果列表
retry_results: 重试结果列表
Returns:
合并后的完整结果列表
"""
print("正在合并结果...")
# 创建重试结果的映射,用于快速查找
retry_map = {}
for result in retry_results:
key = (result['sentence1'], result['sentence2'], result['source_id'])
retry_map[key] = result
merged_results = []
replaced_count = 0
for original_result in original_results:
key = (original_result['sentence1'], original_result['sentence2'], original_result['source_id'])
# 如果原始结果是失败的,并且重试结果存在,则用重试结果替换
if original_result.get('label') == -1 and key in retry_map:
merged_results.append(retry_map[key])
replaced_count += 1
else:
merged_results.append(original_result)
print(f"✓ 合并完成,替换了 {replaced_count} 个失败结果")
return merged_results
# 执行重新标注失败数据
print("=" * 60)
print("开始重新标注失败数据")
print("=" * 60)
# 加载上下文数据
context_dict = load_context_data("batch_deduplication_results_619-1103_01.csv")
# 从JSON文件加载失败数据
failed_data, original_results = load_failed_data_from_json("segmentation_results_from_7.json")
if len(failed_data) == 0:
print("没有需要重新标注的数据,程序结束。")
exit()
print(f"\n开始重新标注 {len(failed_data)} 个失败的句子对...")
# 将失败数据转换为句子对格式
retry_sentence_pairs = convert_failed_data_to_sentence_pairs(failed_data)
# 按source_id分组处理(与原始代码保持一致的处理方式)
from collections import defaultdict
grouped_by_source_id = defaultdict(list)
for pair in retry_sentence_pairs:
grouped_by_source_id[pair['source_id']].append(pair)
print(f"失败数据涉及 {len(grouped_by_source_id)} 个source_id")
# 处理每个source_id的失败数据
all_retry_results = []
for source_id in sorted(grouped_by_source_id.keys()):
current_sentence_pairs = grouped_by_source_id[source_id]
print(f"\n{'=' * 50}")
print(f"重新处理 source_id: {source_id}")
print(f"{'=' * 50}")
print(f" - 该source_id的失败句子对数量: {len(current_sentence_pairs)}")
# 显示第一个句子对预览
if len(current_sentence_pairs) > 0:
first_pair = current_sentence_pairs[0]
print(f" - 第一个句子对预览:")
print(f" 句子1: {first_pair['sentence1'][:60]}...")
print(f" 句子2: {first_pair['sentence2'][:60]}...")
# 检查上下文数据
if source_id in context_dict and context_dict[source_id]:
print(f" - 上下文数据: 可用,长度 {len(context_dict[source_id])} 字符")
print(f" - 上下文预览: {context_dict[source_id][:100]}...")
else:
print(f" - 上下文数据: 不可用")
try:
print(f" - 开始重新标注...")
# 执行批量标注(batch_size=8)
current_results = process_batch_segmentation(current_sentence_pairs, context_dict, batch_size=8)
# 添加到总结果中
all_retry_results.extend(current_results)
# 统计当前source_id的结果
current_successful = len([r for r in current_results if r['label'] in [0, 1]])
current_failed = len([r for r in current_results if r['label'] == -1])
print(f" ✓ source_id {source_id} 重新标注完成")
print(f" - 成功标注: {current_successful}")
print(f" - 仍然失败: {current_failed}")
if current_successful > 0:
current_label_0 = len([r for r in current_results if r['label'] == 0])
current_label_1 = len([r for r in current_results if r['label'] == 1])
print(f" - 标签0(同段落): {current_label_0}")
print(f" - 标签1(分段): {current_label_1}")
# 添加延时
time.sleep(1)
except Exception as e:
print(f" ✗ source_id {source_id} 重新标注失败: {str(e)}")
# 为失败的source_id添加错误记录
for pair in current_sentence_pairs:
all_retry_results.append({
"sentence1": pair["sentence1"],
"sentence2": pair["sentence2"],
"label": -1,
"reason": f"source_id重试异常:{str(e)}",
"source_id": pair["source_id"]
})
print(f"\n{'=' * 60}")
print("重新标注完成!")
print(f"{'=' * 60}")
# 统计重试结果
retry_successful = len([r for r in all_retry_results if r['label'] in [0, 1]])
retry_failed = len([r for r in all_retry_results if r['label'] == -1])
print(f"重试结果统计:")
print(f" - 总重试数量: {len(all_retry_results)}")
print(f" - 重试成功: {retry_successful}")
print(f" - 重试仍失败: {retry_failed}")
if retry_successful > 0:
retry_label_0 = len([r for r in all_retry_results if r['label'] == 0])
retry_label_1 = len([r for r in all_retry_results if r['label'] == 1])
print(f" - 标签0(同段落): {retry_label_0}")
print(f" - 标签1(分段): {retry_label_1}")
# 合并原始结果和重试结果
final_results = merge_results(original_results, all_retry_results)
# 统计最终结果
final_successful = len([r for r in final_results if r['label'] in [0, 1]])
final_failed = len([r for r in final_results if r['label'] == -1])
print(f"\n最终结果统计:")
print(f" - 总数据量: {len(final_results)}")
print(f" - 成功标注: {final_successful}")
print(f" - 失败标注: {final_failed}")
print(f" - 成功率: {final_successful / len(final_results) * 100:.2f}%")
# 保存合并后的结果
final_result_df = pd.DataFrame(final_results)
final_csv_file = "segmentation_results_from_7_retried.csv"
final_result_df.to_csv(final_csv_file, index=False, encoding='utf-8-sig')
print(f"\n最终结果已保存到: {final_csv_file}")
# 保存详细的JSON结果
final_json_file = 'segmentation_results_from_7_retried.json'
with open(final_json_file, 'w', encoding='utf-8') as f:
json.dump(final_results, f, ensure_ascii=False, indent=2)
print(f"详细JSON结果已保存到: {final_json_file}")
# 显示重试前后对比
print(f"\n重试前后对比:")
original_successful = len([r for r in original_results if r['label'] in [0, 1]])
original_failed = len([r for r in original_results if r['label'] == -1])
print(f" - 重试前成功率: {original_successful / len(original_results) * 100:.2f}%")
print(f" - 重试后成功率: {final_successful / len(final_results) * 100:.2f}%")
print(f" - 成功率提升: {(final_successful - original_successful) / len(final_results) * 100:.2f}%")
# 显示前几条重试成功的结果
successful_retries = [r for r in all_retry_results if r['label'] in [0, 1]]
if len(successful_retries) > 0:
print(f"\n前3条重试成功的结果预览:")
for i, result in enumerate(successful_retries[:3]):
print(f"\n{i + 1}. Source ID: {result['source_id']}")
print(f" 句子1: {result['sentence1'][:50]}...")
print(f" 句子2: {result['sentence2'][:50]}...")
print(f" 标签: {result['label']}")
print(f" 理由: {result['reason']}")

192
4.结构化json-Ai标注/Data_plaus/Data_plaus.py

@@ -0,0 +1,192 @@
import json
from collections import defaultdict
def create_cross_document_boundaries(input_file, output_file):
"""
创建跨文档边界的句子对数据
将source_id n的最后一句与source_id n+1的第一句配对,标签设为1(分段)
"""
# 读取原始数据
with open(input_file, 'r', encoding='utf-8') as f:
data = json.load(f)
# 按source_id分组数据
source_groups = defaultdict(list)
for item in data:
source_id = item['source_id']
source_groups[source_id].append(item)
# 按source_id排序
sorted_source_ids = sorted(source_groups.keys())
# 存储新创建的跨文档边界数据
cross_boundary_data = []
print(f"处理 {len(sorted_source_ids)} 个source_id...")
# 遍历相邻的source_id
for i in range(len(sorted_source_ids) - 1):
current_source_id = sorted_source_ids[i]
next_source_id = sorted_source_ids[i + 1]
current_group = source_groups[current_source_id]
next_group = source_groups[next_source_id]
if len(current_group) == 0 or len(next_group) == 0:
continue
# 获取当前source_id的最后一个句子对的sentence2
last_item = current_group[-1]
last_sentence = last_item['sentence2']
# 获取下一个source_id的第一个句子对的sentence1
first_item = next_group[0]
first_sentence = first_item['sentence1']
# 创建跨文档边界的句子对
cross_boundary_item = {
"sentence1": last_sentence,
"sentence2": first_sentence,
"label": 1, # 跨文档必须分段
"reason": f"跨文档边界: source_id {current_source_id} 的结尾与 source_id {next_source_id} 的开头,属于不同文档,必须分段。",
"source_id": f"{current_source_id}-{next_source_id}",
"boundary_type": "cross_document"
}
cross_boundary_data.append(cross_boundary_item)
print(f"创建跨界边界: {current_source_id} -> {next_source_id}")
print(f" 句子1: {last_sentence[:50]}...")
print(f" 句子2: {first_sentence[:50]}...")
print(f"\n总共创建了 {len(cross_boundary_data)} 个跨文档边界样本")
# 保存跨文档边界数据
with open(output_file, 'w', encoding='utf-8') as f:
json.dump(cross_boundary_data, f, ensure_ascii=False, indent=2)
print(f"跨文档边界数据已保存到: {output_file}")
return cross_boundary_data
def merge_with_original_data(original_file, cross_boundary_file, merged_output_file):
"""
将跨文档边界数据与原始数据合并
"""
# 读取原始数据
with open(original_file, 'r', encoding='utf-8') as f:
original_data = json.load(f)
# 读取跨文档边界数据
with open(cross_boundary_file, 'r', encoding='utf-8') as f:
cross_boundary_data = json.load(f)
# 合并数据
merged_data = original_data + cross_boundary_data
print(f"原始数据: {len(original_data)}")
print(f"跨文档边界数据: {len(cross_boundary_data)}")
print(f"合并后数据: {len(merged_data)}")
# 统计标签分布
label_counts = {}
for item in merged_data:
label = item['label']
label_counts[label] = label_counts.get(label, 0) + 1
print(f"\n合并后标签分布:")
for label, count in label_counts.items():
label_name = "不分段" if label == 0 else "分段"
percentage = count / len(merged_data) * 100
print(f" 标签 {label} ({label_name}): {count} 条 ({percentage:.1f}%)")
# 保存合并数据
with open(merged_output_file, 'w', encoding='utf-8') as f:
json.dump(merged_data, f, ensure_ascii=False, indent=2)
print(f"\n合并数据已保存到: {merged_output_file}")
return merged_data
def analyze_source_structure(input_file):
"""
分析source_id的结构,便于理解数据
"""
with open(input_file, 'r', encoding='utf-8') as f:
data = json.load(f)
# 按source_id分组
source_groups = defaultdict(list)
for item in data:
source_id = item['source_id']
source_groups[source_id].append(item)
print(f"数据结构分析:")
print(f"总共 {len(data)} 个句子对")
print(f"涉及 {len(source_groups)} 个source_id")
print(f"source_id范围: {min(source_groups.keys())} - {max(source_groups.keys())}")
# 显示每个source_id的句子对数量
print(f"\n各source_id的句子对数量:")
sorted_source_ids = sorted(source_groups.keys())
for source_id in sorted_source_ids:
count = len(source_groups[source_id])
print(f" source_id {source_id}: {count} 个句子对")
# 显示前几个source_id的示例
print(f"\n前3个source_id的示例:")
for source_id in sorted_source_ids[:3]:
items = source_groups[source_id]
print(f"\nsource_id {source_id}:")
print(f" 第一个句子对: {items[0]['sentence1'][:30]}... -> {items[0]['sentence2'][:30]}...")
if len(items) > 1:
print(f" 最后一个句子对: {items[-1]['sentence1'][:30]}... -> {items[-1]['sentence2'][:30]}...")
def main():
"""
主函数 - 处理跨文档边界数据
"""
# 文件路径
input_file = "segmentation_results_from_7_retried.json"
cross_boundary_output = "cross_document_boundaries.json"
merged_output = "enhanced_training_data_with_boundaries.json"
print("=" * 60)
print("跨文档边界数据生成")
print("=" * 60)
# 1. 分析原始数据结构
print("1. 分析原始数据结构...")
analyze_source_structure(input_file)
print("\n" + "=" * 60)
# 2. 创建跨文档边界数据
print("2. 创建跨文档边界数据...")
cross_boundary_data = create_cross_document_boundaries(input_file, cross_boundary_output)
print("\n" + "=" * 60)
# 3. 合并数据
print("3. 合并原始数据与跨文档边界数据...")
merged_data = merge_with_original_data(input_file, cross_boundary_output, merged_output)
print("\n" + "=" * 60)
print("处理完成!")
print("=" * 60)
print(f"生成的文件:")
print(f" - 跨文档边界数据: {cross_boundary_output}")
print(f" - 增强训练数据: {merged_output}")
print(f"\n现在可以使用 {merged_output} 进行模型训练")
if __name__ == "__main__":
main()
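说明:下面是一个假设性的玩具示例(非仓库文件),演示 create_cross_document_boundaries 的配对规则:取 source_id n 的最后一个句子对的 sentence2,与 source_id n+1 的第一个句子对的 sentence1 组成新样本,标签固定为 1(跨文档必须分段)。
```python
# 假设性示例(非仓库文件):演示跨文档边界样本的构造规则
from collections import defaultdict

toy_data = [
    {"sentence1": "天气预报开始", "sentence2": "今天多云", "label": 0, "source_id": 7},
    {"sentence1": "今天多云", "sentence2": "明天有雨", "label": 0, "source_id": 7},
    {"sentence1": "禁火令第一条", "sentence2": "禁火令第二条", "label": 0, "source_id": 8},
]

groups = defaultdict(list)
for item in toy_data:
    groups[item["source_id"]].append(item)

sorted_ids = sorted(groups)
boundaries = []
for cur, nxt in zip(sorted_ids, sorted_ids[1:]):
    boundaries.append({
        "sentence1": groups[cur][-1]["sentence2"],  # source_id 7 的最后一句
        "sentence2": groups[nxt][0]["sentence1"],   # source_id 8 的第一句
        "label": 1,
        "source_id": f"{cur}-{nxt}",
        "boundary_type": "cross_document",
    })
print(boundaries)  # [{'sentence1': '明天有雨', 'sentence2': '禁火令第一条', 'label': 1, ...}]
```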

8418
4.结构化json-Ai标注/Data_plaus/cross_document_boundaries.json

File diff suppressed because it is too large

340806
4.结构化json-Ai标注/Data_plaus/enhanced_training_data_with_boundaries.json

File diff suppressed because it is too large

332390
4.结构化json-Ai标注/Data_plaus/segmentation_results_from_7_retried.json

File diff suppressed because it is too large

6782
4.结构化json-Ai标注/all_sentence_pairs_for_annotation.json

File diff suppressed because it is too large

542
4.结构化json-Ai标注/batch_deduplication_report_619-1103_01.txt

@@ -0,0 +1,542 @@
=== 批量多级别去重详细报告 ===
处理日期: 2025-08-11 17:04:28
总共处理: 50 个ID
成功处理: 50 个ID
总体统计:
- 平均压缩比: 24.59%
- 总原始字符数: 108,025
- 总最终字符数: 57,951
- 总减少字符数: 50,038
--- ID 1104 详细报告 ---
原始文本长度: 791 字符
最终文本长度: 791 字符
总体压缩比: 0.00%
各级别处理效果:
1. 段落级去重: 减少 1 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
--- ID 1105 详细报告 ---
原始文本长度: 791 字符
最终文本长度: 791 字符
总体压缩比: 0.00%
各级别处理效果:
1. 段落级去重: 减少 1 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
--- ID 1106 详细报告 ---
原始文本长度: 7591 字符
最终文本长度: 801 字符
总体压缩比: 89.45%
各级别处理效果:
1. 段落级去重: 减少 6791 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
PARAGRAPHS级别移除了 161 项内容
--- ID 1107 详细报告 ---
原始文本长度: 19 字符
最终文本长度: 19 字符
总体压缩比: 0.00%
各级别处理效果:
1. 段落级去重: 减少 1 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
--- ID 1108 详细报告 ---
原始文本长度: 3738 字符
最终文本长度: 1248 字符
总体压缩比: 66.61%
各级别处理效果:
1. 段落级去重: 减少 2491 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
PARAGRAPHS级别移除了 60 项内容
--- ID 1109 详细报告 ---
原始文本长度: 4841 字符
最终文本长度: 4841 字符
总体压缩比: 0.00%
各级别处理效果:
1. 段落级去重: 减少 1 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
--- ID 1110 详细报告 ---
原始文本长度: 177 字符
最终文本长度: 104 字符
总体压缩比: 41.24%
各级别处理效果:
1. 段落级去重: 减少 74 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
PARAGRAPHS级别移除了 1 项内容
--- ID 1111 详细报告 ---
原始文本长度: 212 字符
最终文本长度: 212 字符
总体压缩比: 0.00%
各级别处理效果:
1. 段落级去重: 减少 1 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
--- ID 1112 详细报告 ---
原始文本长度: 190 字符
最终文本长度: 116 字符
总体压缩比: 38.95%
各级别处理效果:
1. 段落级去重: 减少 75 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
PARAGRAPHS级别移除了 1 项内容
--- ID 1113 详细报告 ---
原始文本长度: 1282 字符
最终文本长度: 1282 字符
总体压缩比: 0.00%
各级别处理效果:
1. 段落级去重: 减少 1 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
--- ID 1114 详细报告 ---
原始文本长度: 5262 字符
最终文本长度: 5262 字符
总体压缩比: 0.00%
各级别处理效果:
1. 段落级去重: 减少 1 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
--- ID 1115 详细报告 ---
原始文本长度: 5328 字符
最终文本长度: 2005 字符
总体压缩比: 62.37%
各级别处理效果:
1. 段落级去重: 减少 2707 字符
2. 句子级去重: 减少 616 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
PARAGRAPHS级别移除了 33 项内容
SENTENCES级别移除了 7 项内容
--- ID 1116 详细报告 ---
原始文本长度: 5127 字符
最终文本长度: 5117 字符
总体压缩比: 0.20%
各级别处理效果:
1. 段落级去重: 减少 11 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
PARAGRAPHS级别移除了 1 项内容
--- ID 1117 详细报告 ---
原始文本长度: 400 字符
最终文本长度: 400 字符
总体压缩比: 0.00%
各级别处理效果:
1. 段落级去重: 减少 1 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
--- ID 1118 详细报告 ---
原始文本长度: 1296 字符
最终文本长度: 817 字符
总体压缩比: 36.96%
各级别处理效果:
1. 段落级去重: 减少 480 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
PARAGRAPHS级别移除了 11 项内容
--- ID 1119 详细报告 ---
原始文本长度: 445 字符
最终文本长度: 284 字符
总体压缩比: 36.18%
各级别处理效果:
1. 段落级去重: 减少 162 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
PARAGRAPHS级别移除了 2 项内容
--- ID 1120 详细报告 ---
原始文本长度: 795 字符
最终文本长度: 422 字符
总体压缩比: 46.92%
各级别处理效果:
1. 段落级去重: 减少 374 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
PARAGRAPHS级别移除了 10 项内容
--- ID 1121 详细报告 ---
原始文本长度: 796 字符
最终文本长度: 424 字符
总体压缩比: 46.73%
各级别处理效果:
1. 段落级去重: 减少 373 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
PARAGRAPHS级别移除了 10 项内容
--- ID 1122 详细报告 ---
原始文本长度: 125 字符
最终文本长度: 125 字符
总体压缩比: 0.00%
各级别处理效果:
1. 段落级去重: 减少 1 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
--- ID 1123 详细报告 ---
原始文本长度: 37 字符
最终文本长度: 37 字符
总体压缩比: 0.00%
各级别处理效果:
1. 段落级去重: 减少 1 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
--- ID 1124 详细报告 ---
原始文本长度: 3675 字符
最终文本长度: 3175 字符
总体压缩比: 13.61%
各级别处理效果:
1. 段落级去重: 减少 501 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
PARAGRAPHS级别移除了 14 项内容
--- ID 1125 详细报告 ---
原始文本长度: 498 字符
最终文本长度: 249 字符
总体压缩比: 50.00%
各级别处理效果:
1. 段落级去重: 减少 250 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
PARAGRAPHS级别移除了 1 项内容
--- ID 1126 详细报告 ---
原始文本长度: 2461 字符
最终文本长度: 486 字符
总体压缩比: 80.25%
各级别处理效果:
1. 段落级去重: 减少 1976 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
PARAGRAPHS级别移除了 40 项内容
--- ID 1127 详细报告 ---
原始文本长度: 2442 字符
最终文本长度: 1120 字符
总体压缩比: 54.14%
各级别处理效果:
1. 段落级去重: 减少 1323 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
PARAGRAPHS级别移除了 15 项内容
--- ID 1128 详细报告 ---
原始文本长度: 2560 字符
最终文本长度: 1779 字符
总体压缩比: 30.51%
各级别处理效果:
1. 段落级去重: 减少 782 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
PARAGRAPHS级别移除了 8 项内容
--- ID 1129 详细报告 ---
原始文本长度: 2561 字符
最终文本长度: 1788 字符
总体压缩比: 30.18%
各级别处理效果:
1. 段落级去重: 减少 774 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
PARAGRAPHS级别移除了 7 项内容
--- ID 1130 详细报告 ---
原始文本长度: 673 字符
最终文本长度: 673 字符
总体压缩比: 0.00%
各级别处理效果:
1. 段落级去重: 减少 1 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
--- ID 1131 详细报告 ---
原始文本长度: 264 字符
最终文本长度: 264 字符
总体压缩比: 0.00%
各级别处理效果:
1. 段落级去重: 减少 1 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
--- ID 1132 详细报告 ---
原始文本长度: 1566 字符
最终文本长度: 1442 字符
总体压缩比: 7.92%
各级别处理效果:
1. 段落级去重: 减少 125 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
PARAGRAPHS级别移除了 2 项内容
--- ID 1133 详细报告 ---
原始文本长度: 1559 字符
最终文本长度: 1559 字符
总体压缩比: 0.00%
各级别处理效果:
1. 段落级去重: 减少 1 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
--- ID 1134 详细报告 ---
原始文本长度: 2510 字符
最终文本长度: 356 字符
总体压缩比: 85.82%
各级别处理效果:
1. 段落级去重: 减少 2155 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
PARAGRAPHS级别移除了 141 项内容
--- ID 1135 详细报告 ---
原始文本长度: 2530 字符
最终文本长度: 380 字符
总体压缩比: 84.98%
各级别处理效果:
1. 段落级去重: 减少 2151 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
PARAGRAPHS级别移除了 158 项内容
--- ID 1136 详细报告 ---
原始文本长度: 251 字符
最终文本长度: 251 字符
总体压缩比: 0.00%
各级别处理效果:
1. 段落级去重: 减少 1 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
--- ID 1137 详细报告 ---
原始文本长度: 3153 字符
最终文本长度: 571 字符
总体压缩比: 81.89%
各级别处理效果:
1. 段落级去重: 减少 2583 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
PARAGRAPHS级别移除了 147 项内容
--- ID 1138 详细报告 ---
原始文本长度: 917 字符
最终文本长度: 883 字符
总体压缩比: 3.71%
各级别处理效果:
1. 段落级去重: 减少 35 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
PARAGRAPHS级别移除了 1 项内容
--- ID 1139 详细报告 ---
原始文本长度: 908 字符
最终文本长度: 857 字符
总体压缩比: 5.62%
各级别处理效果:
1. 段落级去重: 减少 52 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
PARAGRAPHS级别移除了 1 项内容
--- ID 1140 详细报告 ---
原始文本长度: 2797 字符
最终文本长度: 1656 字符
总体压缩比: 40.79%
各级别处理效果:
1. 段落级去重: 减少 1142 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
PARAGRAPHS级别移除了 25 项内容
--- ID 1141 详细报告 ---
原始文本长度: 800 字符
最终文本长度: 800 字符
总体压缩比: 0.00%
各级别处理效果:
1. 段落级去重: 减少 1 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
--- ID 1142 详细报告 ---
原始文本长度: 618 字符
最终文本长度: 598 字符
总体压缩比: 3.24%
各级别处理效果:
1. 段落级去重: 减少 21 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
PARAGRAPHS级别移除了 1 项内容
--- ID 1143 详细报告 ---
原始文本长度: 1330 字符
最终文本长度: 732 字符
总体压缩比: 44.96%
各级别处理效果:
1. 段落级去重: 减少 599 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
PARAGRAPHS级别移除了 11 项内容
--- ID 1144 详细报告 ---
原始文本长度: 22010 字符
最终文本长度: 1494 字符
总体压缩比: 93.21%
各级别处理效果:
1. 段落级去重: 减少 20517 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
PARAGRAPHS级别移除了 342 项内容
--- ID 1145 详细报告 ---
原始文本长度: 42 字符
最终文本长度: 42 字符
总体压缩比: 0.00%
各级别处理效果:
1. 段落级去重: 减少 1 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
--- ID 1146 详细报告 ---
原始文本长度: 771 字符
最终文本长度: 771 字符
总体压缩比: 0.00%
各级别处理效果:
1. 段落级去重: 减少 1 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
--- ID 1147 详细报告 ---
原始文本长度: 1183 字符
最终文本长度: 1183 字符
总体压缩比: 0.00%
各级别处理效果:
1. 段落级去重: 减少 1 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
--- ID 1148 详细报告 ---
原始文本长度: 1184 字符
最终文本长度: 1184 字符
总体压缩比: 0.00%
各级别处理效果:
1. 段落级去重: 减少 1 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
--- ID 1149 详细报告 ---
原始文本长度: 3964 字符
最终文本长度: 3964 字符
总体压缩比: 0.00%
各级别处理效果:
1. 段落级去重: 减少 1 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
--- ID 1150 详细报告 ---
原始文本长度: 1263 字符
最终文本长度: 1191 字符
总体压缩比: 5.70%
各级别处理效果:
1. 段落级去重: 减少 73 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
PARAGRAPHS级别移除了 3 项内容
--- ID 1151 详细报告 ---
原始文本长度: 1611 字符
最终文本长度: 1524 字符
总体压缩比: 5.40%
各级别处理效果:
1. 段落级去重: 减少 88 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
PARAGRAPHS级别移除了 8 项内容
--- ID 1152 详细报告 ---
原始文本长度: 1810 字符
最终文本长度: 1046 字符
总体压缩比: 42.21%
各级别处理效果:
1. 段落级去重: 减少 765 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符
PARAGRAPHS级别移除了 16 项内容
--- ID 1153 详细报告 ---
原始文本长度: 835 字符
最终文本长度: 835 字符
总体压缩比: 0.00%
各级别处理效果:
1. 段落级去重: 减少 1 字符
2. 句子级去重: 减少 -1 字符
3. 短语级去重: 减少 0 字符
4. 最终标点规范化: 减少 0 字符

51
4.结构化json-Ai标注/batch_deduplication_results_619-1103_01.csv

File diff suppressed because one or more lines are too long

7912
4.结构化json-Ai标注/segmentation_results_from_7.json

File diff suppressed because it is too large

332390
4.结构化json-Ai标注/segmentation_results_from_7_retried.json

File diff suppressed because it is too large

44907
4.结构化json-Ai标注/test_dataset.json

File diff suppressed because it is too large

219
4.结构化json-Ai标注/train and test.py

@@ -0,0 +1,219 @@
import json
import random
from collections import defaultdict
import pandas as pd
def split_dataset_by_source_id(input_file, test_size=150, random_seed=42):
"""
根据source_id随机划分数据集为训练集和测试集
Args:
input_file: 输入的JSON文件路径
test_size: 测试集中source_id的数量
random_seed: 随机种子,确保结果可重现
Returns:
train_data, test_data, train_source_ids, test_source_ids: 训练集、测试集数据及对应的source_id列表
"""
# 设置随机种子
random.seed(random_seed)
print(f"正在读取文件: {input_file}")
try:
# 读取JSON文件
with open(input_file, 'r', encoding='utf-8') as f:
all_data = json.load(f)
print(f"✓ 成功读取文件,总记录数: {len(all_data)}")
# 按source_id分组
source_id_groups = defaultdict(list)
for item in all_data:
source_id_groups[item['source_id']].append(item)
# 获取所有unique的source_id
all_source_ids = list(source_id_groups.keys())
total_source_ids = len(all_source_ids)
print(f"✓ 发现 {total_source_ids} 个不同的source_id")
# 检查测试集大小是否合理
if test_size >= total_source_ids:
print(f"✗ 错误:测试集大小 ({test_size}) 大于等于总source_id数量 ({total_source_ids})")
print(f"建议将测试集大小设置为小于 {total_source_ids}")
return None, None
# 随机选择测试集的source_id
test_source_ids = random.sample(all_source_ids, test_size)
train_source_ids = [sid for sid in all_source_ids if sid not in test_source_ids]
print(f"✓ 随机选择了 {len(test_source_ids)} 个source_id作为测试集")
print(f"✓ 剩余 {len(train_source_ids)} 个source_id作为训练集")
# 构建训练集和测试集
train_data = []
test_data = []
for source_id in train_source_ids:
train_data.extend(source_id_groups[source_id])
for source_id in test_source_ids:
test_data.extend(source_id_groups[source_id])
print(f"\n=== 数据集划分结果 ===")
print(f"训练集:")
print(f" - Source ID数量: {len(train_source_ids)}")
print(f" - 记录数量: {len(train_data)}")
print(f"测试集:")
print(f" - Source ID数量: {len(test_source_ids)}")
print(f" - 记录数量: {len(test_data)}")
# 统计标签分布
def get_label_distribution(data, dataset_name):
label_counts = defaultdict(int)
for item in data:
label_counts[item['label']] += 1
print(f"\n{dataset_name}标签分布:")
for label, count in sorted(label_counts.items()):
percentage = (count / len(data) * 100) if len(data) > 0 else 0
print(f" 标签 {label}: {count} 条 ({percentage:.2f}%)")
return label_counts
train_labels = get_label_distribution(train_data, "训练集")
test_labels = get_label_distribution(test_data, "测试集")
# 显示选中的source_id
print(f"\n=== 测试集Source ID列表 ===")
print(f"测试集source_id: {sorted(test_source_ids)}")
print(f"\n=== 训练集Source ID列表 ===")
print(f"训练集source_id: {sorted(train_source_ids)}")
return train_data, test_data, train_source_ids, test_source_ids
except FileNotFoundError:
print(f"✗ 错误:找不到文件 {input_file}")
return None, None, None, None
except json.JSONDecodeError as e:
print(f"✗ 错误:JSON文件格式错误 - {str(e)}")
return None, None, None, None
except Exception as e:
print(f"✗ 错误:处理文件时出现异常 - {str(e)}")
return None, None, None, None
def save_dataset(data, filename, description):
"""
保存数据集到JSON文件
Args:
data: 要保存的数据
filename: 输出文件名
description: 数据集描述
"""
try:
with open(filename, 'w', encoding='utf-8') as f:
json.dump(data, f, ensure_ascii=False, indent=2)
print(f"{description}已保存到: {filename}")
return True
except Exception as e:
print(f"✗ 保存{description}时出错: {str(e)}")
return False
def create_summary_report(train_data, test_data, train_source_ids, test_source_ids):
"""
创建数据集划分的详细报告
"""
summary = {
"split_info": {
"total_source_ids": len(train_source_ids) + len(test_source_ids),
"train_source_ids": len(train_source_ids),
"test_source_ids": len(test_source_ids),
"total_records": len(train_data) + len(test_data),
"train_records": len(train_data),
"test_records": len(test_data)
},
"train_source_id_list": sorted(train_source_ids),
"test_source_id_list": sorted(test_source_ids),
"label_distribution": {
"train": {},
"test": {}
}
}
# 计算标签分布
for dataset_name, data in [("train", train_data), ("test", test_data)]:
label_counts = defaultdict(int)
for item in data:
label_counts[item['label']] += 1
summary["label_distribution"][dataset_name] = dict(label_counts)
# 保存报告
with open('dataset_split_summary.json', 'w', encoding='utf-8') as f:
json.dump(summary, f, ensure_ascii=False, indent=2)
print(f"✓ 数据集划分报告已保存到: dataset_split_summary.json")
# 主程序执行
if __name__ == "__main__":
print("=" * 60)
print("数据集划分程序")
print("=" * 60)
# 输入文件名
input_file = "segmentation_results_from_7_retried.json"
# 执行数据集划分
train_data, test_data, train_source_ids, test_source_ids = split_dataset_by_source_id(
input_file=input_file,
test_size=150,
random_seed=42
)
if train_data is not None and test_data is not None:
print(f"\n{'=' * 60}")
print("开始保存数据集文件")
print(f"{'=' * 60}")
# 保存训练集
train_success = save_dataset(train_data, "train_dataset.json", "训练集")
# 保存测试集
test_success = save_dataset(test_data, "test_dataset.json", "测试集")
if train_success and test_success:
# 创建详细报告
create_summary_report(train_data, test_data, train_source_ids, test_source_ids)
print(f"\n{'=' * 60}")
print("数据集划分完成!")
print(f"{'=' * 60}")
print("生成的文件:")
print("1. train_dataset.json - 训练集数据")
print("2. test_dataset.json - 测试集数据")
print("3. dataset_split_summary.json - 划分报告")
# 验证数据完整性
print(f"\n=== 数据完整性验证 ===")
original_count = len(train_data) + len(test_data)
print(f"原始数据总数: {original_count}")
print(f"训练集 + 测试集: {len(train_data)} + {len(test_data)} = {len(train_data) + len(test_data)}")
if len(train_data) + len(test_data) == original_count:
print("✓ 数据完整性验证通过")
else:
print("✗ 数据完整性验证失败")
else:
print("✗ 保存文件时出现错误")
else:
print("✗ 数据集划分失败")

287485
4.结构化json-Ai标注/train_dataset.json

File diff suppressed because it is too large

344
4.结构化json-Ai标注/统计/Stat.py

@@ -0,0 +1,344 @@
import json
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from collections import Counter
import numpy as np
import warnings
import re
# 忽略matplotlib警告
warnings.filterwarnings('ignore')
# 设置matplotlib后端(避免显示问题)
import matplotlib
matplotlib.use('Agg') # 使用非交互式后端
# 设置中文字体支持
plt.rcParams['font.sans-serif'] = ['SimHei', 'Arial Unicode MS', 'DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False
def diagnose_json_file(file_path):
"""
诊断JSON文件的问题
Args:
file_path (str): JSON文件路径
Returns:
dict: 诊断结果
"""
print(f"正在诊断文件:{file_path}")
print("=" * 50)
try:
with open(file_path, 'r', encoding='utf-8') as f:
content = f.read()
print(f"文件大小:{len(content)} 字符")
print(f"文件前100个字符:{content[:100]}")
print(f"文件后100个字符:{content[-100:]}")
# 检查是否为空文件
if not content.strip():
print("错误:文件为空")
return {"status": "empty", "content": content}
# 尝试解析JSON
try:
data = json.loads(content)
print("✓ JSON格式正确")
return {"status": "valid", "data": data}
except json.JSONDecodeError as e:
print(f"✗ JSON格式错误:{e}")
print(f"错误位置:行 {e.lineno}, 列 {e.colno}")
return {"status": "invalid", "error": str(e), "content": content}
except FileNotFoundError:
print(f"错误:找不到文件 {file_path}")
return {"status": "not_found"}
except Exception as e:
print(f"读取文件时出错:{e}")
return {"status": "error", "error": str(e)}
def try_fix_json(content):
"""
尝试修复常见的JSON格式问题
Args:
content (str): 原始文件内容
Returns:
list: 修复后的数据,如果修复失败则返回None
"""
print("\n尝试修复JSON格式...")
# 常见修复方法
fixes = [
# 1. 如果是JSONL格式(每行一个JSON对象)
lambda x: [json.loads(line) for line in x.strip().split('\n') if line.strip()],
# 2. 如果缺少最外层的方括号
lambda x: json.loads('[' + x + ']'),
# 3. 如果有多个JSON对象但没有用逗号分隔
lambda x: json.loads('[' + re.sub(r'}\s*{', '},{', x) + ']'),
# 4. 如果有trailing comma
lambda x: json.loads(re.sub(r',\s*}', '}', re.sub(r',\s*]', ']', x))),
# 5. 如果单引号而非双引号
lambda x: json.loads(x.replace("'", '"')),
]
for i, fix_func in enumerate(fixes, 1):
try:
print(f"尝试修复方法 {i}...")
result = fix_func(content)
if isinstance(result, list) and len(result) > 0:
print(f"✓ 修复成功!找到 {len(result)} 条数据")
return result
elif isinstance(result, dict):
print(f"✓ 修复成功!找到 1 条数据")
return [result]
except Exception as e:
print(f"✗ 修复方法 {i} 失败:{e}")
print("所有修复方法都失败了")
return None
def load_and_analyze_json(file_path):
"""
加载JSON文件并统计标签分布包含错误处理和修复功能
Args:
file_path (str): JSON文件路径
Returns:
tuple: (标签统计结果, 总数)
"""
# 首先诊断文件
diagnosis = diagnose_json_file(file_path)
if diagnosis["status"] == "not_found":
return None, None
elif diagnosis["status"] == "empty":
print("文件为空,无法分析")
return None, None
elif diagnosis["status"] == "valid":
data = diagnosis["data"]
elif diagnosis["status"] == "invalid":
# 尝试修复
fixed_data = try_fix_json(diagnosis["content"])
if fixed_data is None:
print("无法修复JSON格式错误")
return None, None
data = fixed_data
else:
print(f"未知错误:{diagnosis.get('error', '未知')}")
return None, None
# 确保数据是列表格式
if not isinstance(data, list):
data = [data]
print(f"\n成功加载数据,共 {len(data)} 条记录")
# 检查数据结构
if len(data) == 0:
print("数据为空")
return None, None
# 检查第一条数据的结构
first_item = data[0]
print(f"第一条数据结构:{list(first_item.keys()) if isinstance(first_item, dict) else type(first_item)}")
# 提取标签
labels = []
for i, item in enumerate(data):
if isinstance(item, dict):
if 'label' in item:
labels.append(item['label'])
elif 'Label' in item:
labels.append(item['Label'])
else:
print(f"警告:第 {i + 1} 条数据缺少 'label' 字段:{item}")
else:
print(f"警告:第 {i + 1} 条数据不是字典格式:{item}")
if not labels:
print("错误:没有找到任何标签数据")
return None, None
# 统计标签数量
label_counts = Counter(labels)
total = len(labels)
# 打印统计结果
print("=" * 50)
print("标签统计结果:")
print("=" * 50)
print(f"总数据条数:{total}")
print("-" * 30)
for label, count in sorted(label_counts.items()):
percentage = (count / total) * 100
print(f"标签 {label}: {count:4d} 条 ({percentage:5.1f}%)")
return label_counts, total
def create_pie_chart(label_counts, total, save_path=None):
"""
创建扇形图
Args:
label_counts (dict): 标签统计结果
total (int): 总数据条数
save_path (str, optional): 保存图片的路径
"""
# 准备数据
labels = []
sizes = []
colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FFEAA7', '#DDA0DD']
for label, count in sorted(label_counts.items()):
if label == 0:
labels.append(f'不分段 (Label {label})')
else:
labels.append(f'分段 (Label {label})')
sizes.append(count)
# 创建图形
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 7))
# 扇形图
wedges, texts, autotexts = ax1.pie(sizes, labels=labels, autopct='%1.1f%%',
colors=colors[:len(sizes)], startangle=90,
explode=[0.05] * len(sizes))
# 美化扇形图
ax1.set_title('文本分段标签分布统计', fontsize=16, fontweight='bold', pad=20)
# 调整文本样式
for autotext in autotexts:
autotext.set_color('white')
autotext.set_fontweight('bold')
autotext.set_fontsize(12)
for text in texts:
text.set_fontsize(11)
# 柱状图
ax2.bar(range(len(label_counts)), sizes, color=colors[:len(sizes)], alpha=0.7)
ax2.set_title('标签数量柱状图', fontsize=16, fontweight='bold', pad=20)
ax2.set_xlabel('标签类型', fontsize=12)
ax2.set_ylabel('数量', fontsize=12)
# 设置x轴标签
ax2.set_xticks(range(len(label_counts)))
ax2.set_xticklabels([f'Label {label}' for label in sorted(label_counts.keys())])
# 在柱状图上添加数值标签
for i, (label, count) in enumerate(sorted(label_counts.items())):
percentage = (count / total) * 100
ax2.text(i, count + total * 0.01, f'{count}\n({percentage:.1f}%)',
ha='center', va='bottom', fontweight='bold')
# 调整布局
plt.tight_layout()
# 保存图片
if save_path:
try:
plt.savefig(save_path, dpi=300, bbox_inches='tight')
print(f"\n图片已保存到:{save_path}")
except Exception as e:
print(f"保存图片时出错:{e}")
print("图表生成完成,请查看保存的图片文件。")
# 关闭图形以释放内存
plt.close(fig)
def create_detailed_report(label_counts, total, file_path):
"""
创建详细报告
Args:
label_counts (dict): 标签统计结果
total (int): 总数据条数
file_path (str): 原始JSON文件路径
"""
report = []
report.append("=" * 60)
report.append("文本分段标签分布统计报告")
report.append("=" * 60)
report.append(f"数据源文件:{file_path}")
report.append(f"分析时间:{np.datetime64('now', 'D')}")
report.append(f"总数据条数:{total}")
report.append("")
report.append("标签分布详情:")
report.append("-" * 40)
for label, count in sorted(label_counts.items()):
percentage = (count / total) * 100
label_desc = "不分段" if label == 0 else "分段"
report.append(f"Label {label} ({label_desc}):{count:4d} 条 ({percentage:5.1f}%)")
report.append("")
report.append("标签含义说明:")
report.append("- Label 0:两句话不需要分段,属于同一段落")
report.append("- Label 1:两句话需要分段,属于不同段落")
# 打印报告
for line in report:
print(line)
# 保存报告到文件
report_file = file_path.replace('.json', '_analysis_report.txt')
try:
with open(report_file, 'w', encoding='utf-8') as f:
f.write('\n'.join(report))
print(f"\n详细报告已保存到:{report_file}")
except Exception as e:
print(f"保存报告时出错:{e}")
def main():
"""主函数"""
# JSON文件路径
json_file = 'test_dataset.json'
print("JSON文件分析工具 - 增强版")
print("=" * 50)
# 加载并分析数据
label_counts, total = load_and_analyze_json(json_file)
if label_counts is not None:
# 创建扇形图
image_path = json_file.replace('.json', '_pie_chart.png')
create_pie_chart(label_counts, total, image_path)
# 创建详细报告
create_detailed_report(label_counts, total, json_file)
print("\n分析完成!")
print("=" * 50)
else:
print("分析失败,请检查文件内容和格式。")
print("\n建议:")
print("1. 确保文件存在且不为空")
print("2. 检查JSON格式是否正确")
print("3. 确保每条数据都有'label'字段")
print("4. 如果是JSONL格式,确保每行都是有效的JSON对象")
if __name__ == "__main__":
main()
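说明:下面是一个假设性的最小示例(非仓库文件),演示 try_fix_json 中第一种修复方式的效果:把 JSONL 格式(每行一个 JSON 对象)按行解析后合成列表。
```python
# 假设性示例(非仓库文件):try_fix_json 中 JSONL 修复方式的效果
import json

jsonl_content = '{"label": 0}\n{"label": 1}\n{"label": 0}\n'
records = [json.loads(line) for line in jsonl_content.strip().split('\n') if line.strip()]
print(records)       # [{'label': 0}, {'label': 1}, {'label': 0}]
print(len(records))  # 3
```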

293
4.结构化json-Ai标注/统计/Token_Stat.py

@@ -0,0 +1,293 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
BERT Token数量统计与可视化
统计sentence1和最后一个sentence2的token数量分布
"""
import json
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from collections import Counter
from transformers import AutoTokenizer
import warnings
# 忽略transformers的警告
warnings.filterwarnings("ignore")
# 设置matplotlib后端,避免显示问题
plt.switch_backend('Agg')
def load_sentence_pairs(file_path):
"""加载句子对数据"""
try:
with open(file_path, 'r', encoding='utf-8') as f:
data = json.load(f)
print(f"成功加载 {len(data)} 个句子对")
return data
except FileNotFoundError:
print(f"错误:找不到文件 {file_path}")
return None
except json.JSONDecodeError:
print(f"错误:JSON文件格式错误")
return None
except Exception as e:
print(f"加载文件时发生错误:{e}")
return None
def initialize_tokenizer(model_name="bert-base-chinese"):
"""初始化BERT tokenizer"""
try:
print(f"初始化 {model_name} tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
print("Tokenizer初始化成功")
return tokenizer
except Exception as e:
print(f"初始化tokenizer失败:{e}")
print("尝试使用备用tokenizer...")
try:
tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")
print("成功使用多语言BERT tokenizer")
return tokenizer
except Exception as e2:
print(f"备用tokenizer也失败:{e2}")
return None
def count_bert_tokens(text, tokenizer):
"""计算文本的BERT token数量(不包含特殊token)"""
if not text or text.strip() == "":
return 0
try:
# 使用tokenizer编码文本,不添加特殊token
tokens = tokenizer.encode(text, add_special_tokens=False)
return len(tokens)
except Exception as e:
print(f"计算token时出错:{e}")
return 0
def get_token_range_label(token_count):
"""根据token数量获取对应的区间标签"""
range_start = (token_count // 100) * 100
range_end = range_start + 99
return f"{range_start}-{range_end}"
def analyze_token_distribution(sentence_pairs, tokenizer):
"""分析token分布"""
print("\n开始分析token分布...")
# 收集所有sentence1的token数量和对应的source_id
sentence1_tokens = []
token_details = [] # 存储详细信息:(token_count, source_id, sentence_type, sentence_text)
for pair in sentence_pairs:
sentence1 = pair.get('sentence1', '')
source_id = pair.get('source_id', 'unknown')
token_count = count_bert_tokens(sentence1, tokenizer)
sentence1_tokens.append(token_count)
token_details.append((token_count, source_id, 'sentence1', sentence1))
# 获取最后一个句子对的sentence2
last_sentence2_tokens = 0
if sentence_pairs:
last_pair = sentence_pairs[-1]
last_sentence2 = last_pair.get('sentence2', '')
last_source_id = last_pair.get('source_id', 'unknown')
last_sentence2_tokens = count_bert_tokens(last_sentence2, tokenizer)
if last_sentence2_tokens > 0:
token_details.append((last_sentence2_tokens, last_source_id, 'sentence2', last_sentence2))
print(f"处理了 {len(sentence1_tokens)} 个sentence1")
print(f"最后一个sentence2的token数量: {last_sentence2_tokens}")
return sentence1_tokens, last_sentence2_tokens, token_details
def create_token_distribution_chart(sentence1_tokens, last_sentence2_tokens):
"""创建token分布柱状图"""
print("\n生成token分布图...")
# 合并所有需要统计的token数量
all_tokens = sentence1_tokens + [last_sentence2_tokens] if last_sentence2_tokens > 0 else sentence1_tokens
# 计算最大token数量以确定区间范围
max_tokens = max(all_tokens) if all_tokens else 0
max_range = ((max_tokens // 100) + 1) * 100
# 创建区间
ranges = []
range_labels = []
for i in range(0, max_range, 100):
ranges.append((i, i + 99))
range_labels.append(f"{i}-{i + 99}")
# 统计每个区间的句子数量
range_counts = [0] * len(ranges)
for token_count in all_tokens:
range_index = token_count // 100
if range_index < len(range_counts):
range_counts[range_index] += 1
# 创建图表
plt.figure(figsize=(12, 8))
# 创建柱状图
bars = plt.bar(range_labels, range_counts, color='skyblue', edgecolor='navy', alpha=0.7)
# 设置图表属性
plt.title('BERT Token Count Distribution', fontsize=16, fontweight='bold')
plt.xlabel('Token Count Range', fontsize=12)
plt.ylabel('Number of Sentences', fontsize=12)
plt.xticks(rotation=45, ha='right')
plt.grid(axis='y', alpha=0.3)
# 在柱子上添加数值标签
for bar, count in zip(bars, range_counts):
if count > 0: # 只在有数据的柱子上显示标签
plt.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.5,
str(count), ha='center', va='bottom', fontsize=10)
# 调整布局
plt.tight_layout()
# 显示统计信息
total_sentences = len(all_tokens)
avg_tokens = np.mean(all_tokens) if all_tokens else 0
median_tokens = np.median(all_tokens) if all_tokens else 0
# 在图表上添加统计信息文本框
stats_text = f'Total Sentences: {total_sentences}\n'
stats_text += f'Average Tokens: {avg_tokens:.1f}\n'
stats_text += f'Median Tokens: {median_tokens:.1f}\n'
stats_text += f'Max Tokens: {max_tokens}'
plt.text(0.02, 0.98, stats_text, transform=plt.gca().transAxes,
bbox=dict(boxstyle='round', facecolor='white', alpha=0.8),
verticalalignment='top', fontsize=10)
return plt
def find_high_token_sentences(token_details, threshold=300):
"""查找token数量超过阈值的句子"""
print(f"\n=== Token数量超过{threshold}的句子 ===")
high_token_sentences = [(count, source_id, sentence_type, sentence)
for count, source_id, sentence_type, sentence in token_details
if count > threshold]
if not high_token_sentences:
print(f"没有找到token数量超过{threshold}的句子")
return []
# 按token数量降序排列
high_token_sentences.sort(key=lambda x: x[0], reverse=True)
print(f"找到 {len(high_token_sentences)} 个token数量超过{threshold}的句子:")
print("-" * 80)
for i, (token_count, source_id, sentence_type, sentence) in enumerate(high_token_sentences, 1):
print(f"{i}. Source ID: {source_id}")
print(f" Type: {sentence_type}")
print(f" Token Count: {token_count}")
print(f" Content: {sentence[:100]}{'...' if len(sentence) > 100 else ''}")
print("-" * 80)
# 保存到CSV文件
df_high_tokens = pd.DataFrame(high_token_sentences,
columns=['token_count', 'source_id', 'sentence_type', 'sentence_text'])
output_file = f'high_token_sentences_over_{threshold}.csv'
df_high_tokens.to_csv(output_file, index=False, encoding='utf-8-sig')
print(f"详细信息已保存到: {output_file}")
return high_token_sentences
"""打印详细统计信息"""
print("\n=== 详细统计信息 ===")
all_tokens = sentence1_tokens + [last_sentence2_tokens] if last_sentence2_tokens > 0 else sentence1_tokens
if not all_tokens:
print("没有数据可统计")
return
print(f"Sentence1总数: {len(sentence1_tokens)}")
print(f"Last Sentence2: {'已包含' if last_sentence2_tokens > 0 else '无数据'}")
print(f"总句子数: {len(all_tokens)}")
print(f"平均token数: {np.mean(all_tokens):.2f}")
print(f"中位数token数: {np.median(all_tokens):.2f}")
print(f"最小token数: {min(all_tokens)}")
print(f"最大token数: {max(all_tokens)}")
print(f"标准差: {np.std(all_tokens):.2f}")
# 按区间统计
print("\n=== 区间分布 ===")
max_tokens = max(all_tokens)
max_range = ((max_tokens // 100) + 1) * 100
for i in range(0, max_range, 100):
count = sum(1 for x in all_tokens if i <= x < i + 100)
if count > 0:
percentage = (count / len(all_tokens)) * 100
print(f"{i}-{i + 99} tokens: {count} 句子 ({percentage:.1f}%)")
def main():
"""主函数"""
# 文件路径
input_file = 'segmentation_results_from_7_retried.json'
# 1. 加载数据
sentence_pairs = load_sentence_pairs(input_file)
if sentence_pairs is None:
return
# 2. 初始化tokenizer
tokenizer = initialize_tokenizer("bert-base-chinese")
if tokenizer is None:
print("无法初始化tokenizer,程序退出")
return
# 3. 分析token分布
sentence1_tokens, last_sentence2_tokens, token_details = analyze_token_distribution(sentence_pairs, tokenizer)
if not sentence1_tokens:
print("没有找到有效的句子数据")
return
# 4. 查找高token数量的句子
high_token_sentences = find_high_token_sentences(token_details, threshold=300)
# 5. 打印详细统计
# print_detailed_statistics(sentence1_tokens, last_sentence2_tokens)
# 6. 创建可视化图表
plt = create_token_distribution_chart(sentence1_tokens, last_sentence2_tokens)
# 7. 保存和显示图表
try:
output_file = 'bert_token_distribution.png'
plt.savefig(output_file, dpi=300, bbox_inches='tight')
print(f"\n图表已保存为: {output_file}")
plt.show()
except Exception as e:
print(f"保存或显示图表时出错: {e}")
# 尝试不显示图表,只保存
try:
plt.savefig('bert_token_distribution.png', dpi=300, bbox_inches='tight')
print("图表已保存,但无法显示")
except Exception as e2:
print(f"保存图表也失败: {e2}")
print("\n分析完成!")
if __name__ == "__main__":
main()
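说明:下面是一个假设性的小例子(非仓库文件),演示 count_bert_tokens 与 get_token_range_label 的配合方式:用 add_special_tokens=False 统计 token 数,再按 100 为步长归入区间(运行前需要能加载 bert-base-chinese)。
```python
# 假设性示例(非仓库文件):统计 token 数并归入以 100 为步长的区间
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")

def token_range_label(text: str) -> str:
    count = len(tokenizer.encode(text, add_special_tokens=False))
    start = (count // 100) * 100
    return f"{count} tokens -> 区间 {start}-{start + 99}"

print(token_range_label("森林防火,人人有责。"))  # 中文基本逐字切分,约 10 个 token,落在 0-99 区间
```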

6782
4.结构化json-Ai标注/统计/all_sentence_pairs_for_annotation.json

File diff suppressed because it is too large

BIN
4.结构化json-Ai标注/统计/bert_token_distribution.png

Binary file not shown (new image, 189 KiB).

11
4.结构化json-Ai标注/统计/high_token_sentences_over_300.csv

@@ -0,0 +1,11 @@
token_count,source_id,sentence_type,sentence_text
402,994,sentence1,目前,国家级强对流天气预警标准为强对流天气蓝色预警,预计未来24小时,三个及以上相邻省区市部分地区将出现八级以上雷暴大风或冰雹,且局地可能出现十级雷暴大风或直径10毫米以上冰雹,或者20毫米每小时以上强度的短时强降水,并伴随成片五个县级行政区以上雷暴大风或冰雹,或者已经出现并可能持续强对流天气黄色预警,预计未来24小时,三个及以上相邻省区市部分地区将出现九级以上雷暴大风或冰雹,且局地可能出现11级雷暴大风或直径20毫米以上冰雹或龙卷风,或者50毫米每小时以上强度的短时强降水,并伴随成片五个县级行政区以上雷暴大风或冰雹,或者已经出现并可能持续强对流天气橙色预警,预计未来24小时,三个及以上相邻省区市部分地区将出现十级以上雷暴大风或冰雹,且局地可能出现12级以上雷暴大风或直径50毫米以上冰雹或强龙卷风,或者80毫米每小时以上强度的短时强降水,并伴随成片五个县级行政区以上雷暴大风或冰雹,或者已经出现并可能持续强对流天气。
381,252,sentence1,提醒家人亲友远离酒后驾驶,闯红灯、超员、超速、超限、超载等违法行为,驾驶农用车、拖拉机、三轮摩托车、电动自行车,不违法载人,文明礼让,不开斗气车,做文明交通的驾驶者,未满12周岁,不骑自行车上路,未满16周岁,不骑电动车上路,做到7个自觉,自觉养成按灯停走,按道行驶,按线通行,按位停放,按章驾乘,自觉系安全带,自觉减速让行做到10步,开车不接驳手机,不争道抢道占道,不随意变道,掉头,不车窗抛物,不乱停乱放,不乱鸣喇叭,不酒后驾驶,不超速超载,不疲劳驾驶,不买卖,伪造变造,故意遮挡、污损号牌,做文明交通的行路者,行人要走人行道,过马路走斑马线,地下通道和人行天桥要优先,不跨护栏,不闯红灯,不占盲道,不在车道内穿行,不低头玩手机,不在道路上嬉戏追逐玩耍,做文明交通的骑行者,骑乘自行车、电动自行车、摩托车,正确佩戴头盔,不加装雨棚、遮阳伞等影响行车安全的附属设施。
377,234,sentence1,中华人民共和国道路交通安全法第第九十一条,饮酒后驾驶机动车的处暂扣6个月机动车驾驶证,并处1000元以上2000元以下罚款,因饮酒后驾驶机动车被处罚,再次饮酒后驾驶机动车的处10日以下拘留,并处1000元以上2000元以下罚款,吊销机动车驾驶证,醉酒驾驶机动车的,由公安机关交通管理部门约束至酒醒,吊销机动车驾驶证,依法追究刑事责任,5年内不得重新取得机动车驾驶证,饮酒后驾驶营运机动车的,处15日拘留,并处5000元罚款,吊销机动车驾驶证,5年内不得重新取得机动车驾驶证,醉酒驾驶营运机动车的,由公安机关交通管理部门约束至酒醒,吊销机动车驾驶证,依法追究刑事责任,10年内不得重新取得机动车驾驶证,重新取得机动车驾驶证后,不得驾驶营运机动车,饮酒后或者醉酒驾驶机动车发生重大交通事故,构成犯罪的,依法追究刑事责任,并由公安机关交通管理部门吊销机动车驾驶证,终生不得重新取得机动车驾驶证啊。
356,270,sentence1,近年来道县坚持以人民为中心的发展理念,准确把握民政工作面临的新形势、新任务、新机遇,聚焦困难群体,扎实推进民生保障工作,各项工作取得显著成效,切实提升了群众的幸福感、获得感和安全感,聚焦困难群众,抓细抓实社会救助,强化低收入人口动态监测和常态化救助帮扶,有序推进社会救助扩围增效,推动社会救助从保生存向防风险促发展转变,足额发放救助资金,一季度保障城市低保075万人次,发放城市低保资金37053万元,保障农村低保对象554万人次,发放农村低保资金170168万元,救助特困供养人员161万人次,发放特困供养金96688万元,扎实开展寒冬送暖专项救助行动,一季度临时救助困难群众980多人次,发放临时救助资金22536万元,聚焦一老一小抓细抓实社会福利事业,落实儿童福利政策,一季度发放孤儿基本生活费13192万元,发放事实无人抚养儿童基本生活费15475万元。
355,1147,sentence1,根据突发事件应对法,中华人民控规规30法、森林防火条例,湖南省森林防火若干规定等有关法律法规规定,特发布森林防火禁火令,以京沪石岗2024年10月1号至2025年4月30号,经过访问,强县林区及林体边缘50米范围内,32,进货期内燕京带活菌送上,眼睛脸上遭淋,严禁在禁扫范围内扫黄、扫田岗、扫草木灰、扫农作物秸秆,及其他能使用户行为严禁在林区内或另让别人焚烧垃圾,扔掉烟头,烧烤烟群,扫虎去龙烧伤取水、烧喷口,燕京鸡丝、电锯、浪熊、扫纸、浪风、烟花鞭炮、燕京方克、林丹等其他野外用火及1月发森林火灾的过程,严禁其他一切未经批准过意外用户,因防止贫穷害工程建设等特需需要野外用火狗,采取剥皮、在今或其内,对未经疲倦三次在禁火范围内野外用火狗,雨林野居馆铺卖,根据森林防火条例相关条款处罚,造成森林火灾,构成房车一个依法决绝刑事责任。
321,274,sentence1,根据突发事件应对法,中华人民共和国森林法、森林防火条例,湖南省森林防火若干规定等有关法律法规规定,特发布森林防火禁火令,以京沪石岗2024年10月1号至2025年4月30号,以经过访问,强县林区及林体边缘50米范围内伤,进货期内燕京带货均送上,眼睛脸上着灵,严禁在禁扫范围内扫黄,扫天岗、扫草木灰、扫农作物秸秆,及其他能使用户行为严禁在林区内或另让别人焚烧垃圾,扔掉烟头,烧烤烟群扫货去农,烧伤取水、烧朋口、燕京鸡丝、电锯、浪熊、扫纸、浪风、烟花鞭炮、燕京方、孔令丹等其他野外用户及1月发森林火灾的过程,严禁其他一切未经批准过野外用火,应防止贫穷化工程建设等,太需需要意外用火锅,那采取包皮是在进货期内对未经疲倦,三次在进货范围内野外用火狗,野林野居馆铺卖。
316,1144,sentence1,森林火灾报警电话119,宜章县人民政府2024年3月16日,宜章县人民政府禁火令,一政令20241号,为有效预防和遏制森林火灾的发生,保护人民群众生命财产和森林资源安全,根据中华人民共和国森林法,中华人民共和国治安管理处罚法、森林防火条例、湖南省森林防火若干规定,郴州市野外用火管理若干规定等法律法规的规定,结合我县实际,特发布本命令,全县行政区域内的森林防火区和农业生产生活区等易诱发森林火灾的区域,实行全年野外用火监督管理,森林防火期,我县全年为森林防火期,森林高火险期每年9月1日至翌年4月30日,和全年之内任何时段发布森林火险三级以上含三级预警及高温、干旱、大风等森林防火紧急状态的森林高火险天气,为森林高火险期,森林防火区和农业生产生活区!
316,1148,sentence1,根据突发事件应该发君华人民控规规30发,森林防火条例,湖南省森林防火若干规定等有关法律法规规定,特发布森林防火禁火令,以京沪石岗2024年10月1号至2021年4月30号,经过访问,强县令曲及令体边缘50米范围内伤,禁火期内燕京带活菌烧伤,眼睛脸上遭淋,严禁在禁烧范围内扫黄、扫田岗、扫草木灰、扫农作物秸秆,及其他能使用户行为严禁在林区内或另让别人焚烧垃圾,扔掉烟头,烧烤烟群扫货去龙,烧伤取水烧喷口,燕京鸡丝、电锯、浪熊、扫纸、浪风、烟花鞭炮、燕京方、孔令丹等其他野外用户,即1月发森林火灾的回程,严禁其他一切未经批准的野外用火,因防止贫穷害工程建设等特需需要野外用火的,采取剥皮、在今或其内,对未经疲倦三次在禁火范围内野外用火狗,雨林野居馆铺卖。
315,254,sentence1,森林火灾报警电话119,宜章县人民政府2024年3月16日,宜章县人民政府禁火令,一政令20241号,为有效预防和遏制森林火灾的发生,保护人民群众生命财产和森林资源安全,根据中华人民共和国森林法、中华人民共和国治安管理处罚法、森林防火条例、湖南省森林防火若干规定,郴州市野外用火管理若干规定等法律法规的规定,结合我县实际,特发布本命令,全县行政区域内的森林防火区和农业生产生活区等易诱发森林火灾的区域,实行全年野外用火监督管理,森林防火期,我县全年为森林防火期,森林高火险期每年9月1日至翌年4月30日,和全年之内任何时段发布森林火险三级以上含三级预警及高温、干旱大风等森林防火紧急状态的森林高火险天气,为森林高火险期,森林防火区和农业生产生活区!
315,1147,sentence1,根据突发事件应对法,中华人民共和国森林法、森林防火条例,湖南省森林防火陆港规定等有关法律法规规定,特发布森林防火禁火令,以京沪石岗2024年10月1号至2025年4月30号,经过防卫,强县林区及林体边缘50米范围内伤,进货期内燕京带货均送上,燕京恋上赵丽颖,燕京在禁售范围内扫黄扫田岗、扫草木灰、扫农作物秸秆,及其他能使用户行为严禁在林区内或另让别人焚烧垃圾,扔掉烟头,烧烤烟群扫货去农烧伤取水烧盆口、燕京鸡丝、电锯、浪熊、扫纸、浪风、烟花鞭炮、燕京方、孔令丹等其他野外用户及1月发森林火灾的回头,严禁其他一切未经批准的野外用火,应防止贫穷化工程建设等,太需要野外用火锅,那采取包皮在进货期内,对未经批卷三次在禁火范围内野外用火狗与林野居广铺卖。
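
The 300-token threshold flagged in this CSV matters because the training scripts later in this commit encode every sentence pair with max_length=512: two long sentences plus the [CLS]/[SEP] markers can exceed that window and are silently truncated. A hedged sketch of the check, with a hypothetical stand-in for one of the long sentences above:

# Illustrative sketch (hypothetical sentences): pairs whose combined length
# exceeds BERT's 512-token window are truncated during encoding, which is
# why sentences over ~300 tokens are flagged in the CSV above.
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
long_s1 = "预警" * 300            # stand-in for a ~400-token sentence such as the first CSV row
s2 = "下面进入下一条内容。"
encoded = tokenizer(long_s1, s2, truncation=True, max_length=512)
print(len(encoded["input_ids"]))  # capped at 512 once the pair is too long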

332390
4.结构化json-Ai标注/统计/segmentation_results_from_7_retried.json

File diff suppressed because it is too large

BIN
4.结构化json-Ai标注/统计/segmentation_results_from_7_retried_pie_chart.png

Binary file not shown.


1431
5.AI标注-model_trian/LoRa+NN/失败案例/train-robert-wwm-ext.py

File diff suppressed because it is too large

1188
5.AI标注-model_trian/全参+NN/train-robert-large.py

File diff suppressed because it is too large

1431
5.AI标注-model_trian/全参+NN/train-robert-wwm-ext-new.py

File diff suppressed because it is too large

1431
5.AI标注-model_trian/全参+NN/train-robert-wwm-ext.py

File diff suppressed because it is too large

1686
5.AI标注-model_trian/全参微调/FreeLB扰动训练/Bert-train_FreeLB.py

File diff suppressed because it is too large

957
5.AI标注-model_trian/全参微调/无验证集训练/Bert-train.py

@@ -0,0 +1,957 @@
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from transformers import (
BertTokenizer,
BertForSequenceClassification,
BertModel,
BertConfig,
TrainingArguments,
Trainer,
DataCollatorWithPadding,
TrainerCallback
)
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import logging
import os
from datetime import datetime
import math
from collections import defaultdict, Counter
# 禁用wandb和其他第三方报告工具
os.environ["WANDB_DISABLED"] = "true"
os.environ["COMET_MODE"] = "disabled"
os.environ["NEPTUNE_MODE"] = "disabled"
# 设置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# 设置matplotlib中文字体
plt.rcParams['font.sans-serif'] = ['DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False
def check_gpu_availability():
"""检查GPU可用性"""
if not torch.cuda.is_available():
raise RuntimeError("❌ GPU不可用!此脚本需要GPU支持。")
gpu_count = torch.cuda.device_count()
gpu_name = torch.cuda.get_device_name(0)
gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024 ** 3
logger.info(f"✅ GPU检查通过!")
logger.info(f" 🔹 可用GPU数量: {gpu_count}")
logger.info(f" 🔹 GPU型号: {gpu_name}")
logger.info(f" 🔹 GPU内存: {gpu_memory:.1f} GB")
# V100优化设置
torch.cuda.empty_cache()
torch.backends.cudnn.benchmark = True
return True, gpu_memory
class LossTracker(TrainerCallback):
"""损失跟踪回调类"""
def __init__(self):
self.train_losses = []
self.eval_losses = []
self.train_steps = []
self.eval_steps = []
self.current_epoch = 0
def on_log(self, args, state, control, logs=None, **kwargs):
if logs:
if 'loss' in logs:
self.train_losses.append(logs['loss'])
self.train_steps.append(state.global_step)
if 'eval_loss' in logs:
self.eval_losses.append(logs['eval_loss'])
self.eval_steps.append(state.global_step)
def on_epoch_end(self, args, state, control, **kwargs):
self.current_epoch = state.epoch
class ConfusionMatrixCallback(TrainerCallback):
"""混淆矩阵生成回调"""
def __init__(self, eval_dataset, tokenizer, output_dir, epochs_interval=20):
self.eval_dataset = eval_dataset
self.tokenizer = tokenizer
self.output_dir = output_dir
self.epochs_interval = epochs_interval
self.confusion_matrices = {}
def on_epoch_end(self, args, state, control, model=None, **kwargs):
current_epoch = int(state.epoch)
if current_epoch % self.epochs_interval == 0 or current_epoch == args.num_train_epochs:
logger.info(f"📊 Generating confusion matrix for epoch {current_epoch}...")
model.eval()
predictions = []
true_labels = []
device = next(model.parameters()).device
with torch.no_grad():
for i in range(len(self.eval_dataset)):
item = self.eval_dataset[i]
input_ids = item['input_ids'].unsqueeze(0).to(device)
attention_mask = item['attention_mask'].unsqueeze(0).to(device)
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
pred = torch.argmax(outputs['logits'], dim=-1).cpu().item()
predictions.append(pred)
true_labels.append(item['labels'].item())
cm = confusion_matrix(true_labels, predictions)
self.confusion_matrices[current_epoch] = cm
self.save_confusion_matrix(cm, current_epoch)
model.train()
def save_confusion_matrix(self, cm, epoch):
"""保存混淆矩阵图"""
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
xticklabels=['Same Paragraph (0)', 'Different Paragraph (1)'],
yticklabels=['Same Paragraph (0)', 'Different Paragraph (1)'])
plt.title(f'Confusion Matrix - Epoch {epoch}')
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
accuracy = np.trace(cm) / np.sum(cm)
plt.text(0.5, -0.15, f'Accuracy: {accuracy:.4f}',
ha='center', transform=plt.gca().transAxes)
plt.tight_layout()
save_path = os.path.join(self.output_dir, f'confusion_matrix_epoch_{epoch}.png')
plt.savefig(save_path, dpi=300, bbox_inches='tight')
plt.close()
logger.info(f" 💾 Confusion matrix saved: {save_path}")
def plot_training_curves(loss_tracker, output_dir):
"""绘制训练损失曲线"""
plt.figure(figsize=(12, 8))
if loss_tracker.train_losses:
plt.subplot(2, 1, 1)
plt.plot(loss_tracker.train_steps, loss_tracker.train_losses,
'b-', label='Training Loss', linewidth=2, alpha=0.8)
plt.title('Training Loss Curve', fontsize=14, fontweight='bold')
plt.xlabel('Training Steps')
plt.ylabel('Loss Value')
plt.legend()
plt.grid(True, alpha=0.3)
if len(loss_tracker.train_losses) > 10:
z = np.polyfit(loss_tracker.train_steps, loss_tracker.train_losses, 1)
p = np.poly1d(z)
plt.plot(loss_tracker.train_steps, p(loss_tracker.train_steps),
'r--', alpha=0.6, label='Trend Line')
plt.legend()
if loss_tracker.eval_losses:
plt.subplot(2, 1, 2)
plt.plot(loss_tracker.eval_steps, loss_tracker.eval_losses,
'g-', label='Validation Loss', linewidth=2, alpha=0.8)
plt.title('Validation Loss Curve', fontsize=14, fontweight='bold')
plt.xlabel('Training Steps')
plt.ylabel('Loss Value')
plt.legend()
plt.grid(True, alpha=0.3)
if loss_tracker.train_losses and loss_tracker.eval_losses:
plt.figure(figsize=(12, 6))
min_len = min(len(loss_tracker.train_losses), len(loss_tracker.eval_losses))
train_steps_aligned = loss_tracker.train_steps[:min_len]
train_losses_aligned = loss_tracker.train_losses[:min_len]
eval_steps_aligned = loss_tracker.eval_steps[:min_len]
eval_losses_aligned = loss_tracker.eval_losses[:min_len]
plt.plot(train_steps_aligned, train_losses_aligned,
'b-', label='Training Loss', linewidth=2, alpha=0.8)
plt.plot(eval_steps_aligned, eval_losses_aligned,
'r-', label='Validation Loss', linewidth=2, alpha=0.8)
plt.title('Training vs Validation Loss Comparison', fontsize=16, fontweight='bold')
plt.xlabel('Training Steps', fontsize=12)
plt.ylabel('Loss Value', fontsize=12)
plt.legend(fontsize=12)
plt.grid(True, alpha=0.3)
if len(eval_losses_aligned) > 20:
recent_train = np.mean(train_losses_aligned[-10:])
recent_eval = np.mean(eval_losses_aligned[-10:])
if recent_eval > recent_train * 1.2:
plt.text(0.7, 0.9, ' Possible Overfitting', transform=plt.gca().transAxes,
bbox=dict(boxstyle="round,pad=0.3", facecolor="yellow", alpha=0.7))
plt.tight_layout()
compare_path = os.path.join(output_dir, 'loss_comparison_curves.png')
plt.savefig(compare_path, dpi=300, bbox_inches='tight')
logger.info(f"📈 Training comparison curves saved: {compare_path}")
plt.tight_layout()
curves_path = os.path.join(output_dir, 'training_curves.png')
plt.savefig(curves_path, dpi=300, bbox_inches='tight')
plt.close()
logger.info(f"📈 Training curves saved: {curves_path}")
class SentencePairDataset(Dataset):
"""句子对数据集类(支持加权采样)"""
def __init__(self, data, tokenizer, max_length=512):
self.data = data
self.tokenizer = tokenizer
self.max_length = max_length
self.valid_data = [item for item in data if item['label'] in [0, 1]]
logger.info(f"原始数据: {len(data)} 条,有效数据: {len(self.valid_data)}")
self.sentence1_list = [item['sentence1'] for item in self.valid_data]
self.sentence2_list = [item['sentence2'] for item in self.valid_data]
self.labels = [item['label'] for item in self.valid_data]
self.class_counts = Counter(self.labels)
self.class_weights = self._compute_class_weights()
self.sample_weights = self._compute_sample_weights()
def _compute_class_weights(self):
"""计算类别权重"""
total_samples = len(self.labels)
class_weights = {}
for label in [0, 1]:
count = self.class_counts[label]
class_weights[label] = total_samples / (2 * count)
return class_weights
def _compute_sample_weights(self):
"""计算每个样本的权重"""
sample_weights = []
for label in self.labels:
sample_weights.append(self.class_weights[label])
return torch.tensor(sample_weights, dtype=torch.float)
def get_weighted_sampler(self):
"""返回加权随机采样器"""
return WeightedRandomSampler(
weights=self.sample_weights,
num_samples=len(self.sample_weights),
replacement=True
)
def __len__(self):
return len(self.valid_data)
def __getitem__(self, idx):
sentence1 = str(self.sentence1_list[idx])
sentence2 = str(self.sentence2_list[idx])
label = self.labels[idx]
encoding = self.tokenizer(
sentence1,
sentence2,
truncation=True,
padding='max_length',
max_length=self.max_length,
return_tensors='pt'
)
return {
'input_ids': encoding['input_ids'].flatten(),
'attention_mask': encoding['attention_mask'].flatten(),
'labels': torch.tensor(label, dtype=torch.long)
}
def load_training_data(train_file):
"""加载训练数据"""
try:
with open(train_file, 'r', encoding='utf-8') as f:
train_data = json.load(f)
logger.info(f"成功加载训练数据: {len(train_data)} 条记录")
return train_data
except Exception as e:
logger.error(f"加载训练数据失败: {str(e)}")
return None
def analyze_data_distribution(data):
"""分析数据分布并计算优化的Focal Loss参数"""
valid_data = [item for item in data if item['label'] in [0, 1]]
label_counts = {}
for item in valid_data:
label = item['label']
label_counts[label] = label_counts.get(label, 0) + 1
total_samples = len(valid_data)
logger.info("=== 训练数据分布分析 ===")
logger.info(f"总有效记录数: {total_samples}")
class_proportions = {}
alpha_weights = []
for label in [0, 1]:
count = label_counts.get(label, 0)
proportion = count / total_samples
class_proportions[label] = proportion
label_name = "同段落" if label == 0 else "不同段落"
logger.info(f"标签 {label} ({label_name}): {count} 条 ({proportion * 100:.2f}%)")
minority_ratio = min(class_proportions.values())
imbalance_ratio = max(class_proportions.values()) / minority_ratio
logger.info(f"\n📊 数据不平衡分析:")
logger.info(f" 🔹 少数类比例: {minority_ratio * 100:.2f}%")
logger.info(f" 🔹 不平衡比率: {imbalance_ratio:.2f}:1")
if imbalance_ratio > 5:
alpha_weights = [0.1, 0.9]
logger.info(" 🎯 使用激进的alpha权重设置")
else:
alpha_weights = [1.0 - class_proportions[0], 1.0 - class_proportions[1]]
if imbalance_ratio > 6:
recommended_gamma = 3.5
logger.info(" 严重不平衡 - 使用Gamma=3.5强化聚焦")
elif imbalance_ratio > 4:
recommended_gamma = 3.0
logger.info(" 中度偏严重不平衡 - 使用Gamma=3.0")
else:
recommended_gamma = 2.5
logger.info(f"\n🎯 优化的Focal Loss参数设置:")
logger.info(f" 🔹 Alpha权重: [多数类={alpha_weights[0]:.3f}, 少数类={alpha_weights[1]:.3f}]")
logger.info(f" 🔹 优化Gamma: {recommended_gamma} (增强难样本聚焦)")
logger.info(f" 🔹 公式: FL(p_t) = -α_t * (1-p_t)^γ * log(p_t)")
logger.info(f" 🔹 加权采样: 启用WeightedRandomSampler")
return label_counts, alpha_weights, recommended_gamma
def compute_metrics(eval_pred):
"""计算评估指标"""
predictions, labels = eval_pred
predictions = np.argmax(predictions, axis=1)
accuracy = accuracy_score(labels, predictions)
return {
'accuracy': accuracy,
}
class FocalLoss(nn.Module):
"""优化的Focal Loss用于处理类别不平衡问题"""
def __init__(self, alpha=None, gamma=3.0, reduction='mean'):
super(FocalLoss, self).__init__()
self.alpha = alpha
self.gamma = gamma
self.reduction = reduction
def forward(self, inputs, targets):
ce_loss = F.cross_entropy(inputs, targets, reduction='none')
pt = torch.exp(-ce_loss)
if self.alpha is not None:
if self.alpha.type() != inputs.data.type():
self.alpha = self.alpha.type_as(inputs.data)
at = self.alpha.gather(0, targets.data.view(-1))
ce_loss = ce_loss * at
focal_weight = (1 - pt) ** self.gamma
focal_loss = focal_weight * ce_loss
if self.reduction == 'mean':
return focal_loss.mean()
elif self.reduction == 'sum':
return focal_loss.sum()
else:
return focal_loss
class ScaledDotProductAttention(nn.Module):
"""缩放点积注意力机制"""
def __init__(self, d_model, dropout=0.1):
super(ScaledDotProductAttention, self).__init__()
self.d_model = d_model
self.dropout = nn.Dropout(dropout)
def forward(self, query, key, value, mask=None):
batch_size, seq_len, d_model = query.size()
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_model)
if mask is not None:
mask_value = torch.finfo(scores.dtype).min
scores = scores.masked_fill(mask == 0, mask_value)
attention_weights = F.softmax(scores, dim=-1)
attention_weights = self.dropout(attention_weights)
output = torch.matmul(attention_weights, value)
return output, attention_weights
class RoBERTaWithScaledAttentionAndFocalLoss(nn.Module):
"""带缩放点积注意力和优化Focal Loss的RoBERTa模型"""
def __init__(self, model_path, num_labels=2, dropout=0.1,
focal_alpha=None, focal_gamma=3.0):
super(RoBERTaWithScaledAttentionAndFocalLoss, self).__init__()
self.roberta = BertModel.from_pretrained(model_path)
self.config = self.roberta.config
self.config.num_labels = num_labels
self.scaled_attention = ScaledDotProductAttention(
d_model=self.config.hidden_size,
dropout=dropout
)
self.dropout = nn.Dropout(dropout)
self.classifier = nn.Linear(self.config.hidden_size, num_labels)
self.focal_loss = FocalLoss(alpha=focal_alpha, gamma=focal_gamma)
self._init_weights()
self.focal_alpha = focal_alpha
self.focal_gamma = focal_gamma
def _init_weights(self):
"""初始化新增层的权重"""
nn.init.normal_(self.classifier.weight, std=0.02)
nn.init.zeros_(self.classifier.bias)
def forward(self, input_ids, attention_mask=None, token_type_ids=None, labels=None):
roberta_outputs = self.roberta(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
return_dict=True
)
sequence_output = roberta_outputs.last_hidden_state
enhanced_output, attention_weights = self.scaled_attention(
query=sequence_output,
key=sequence_output,
value=sequence_output,
mask=attention_mask.unsqueeze(1) if attention_mask is not None else None
)
cls_output = enhanced_output[:, 0, :]
cls_output = self.dropout(cls_output)
logits = self.classifier(cls_output)
loss = None
if labels is not None:
loss = self.focal_loss(logits, labels)
return {
'loss': loss,
'logits': logits,
'hidden_states': enhanced_output,
'attention_weights': attention_weights
}
def save_pretrained(self, save_directory):
"""保存模型"""
os.makedirs(save_directory, exist_ok=True)
model_to_save = self.module if hasattr(self, 'module') else self
torch.save(model_to_save.state_dict(), os.path.join(save_directory, 'pytorch_model.bin'))
config_dict = {
'model_type': 'RoBERTaWithScaledAttentionAndFocalLoss',
'base_model': 'chinese-roberta-wwm-ext',
'num_labels': self.config.num_labels,
'hidden_size': self.config.hidden_size,
'focal_alpha': self.focal_alpha.tolist() if self.focal_alpha is not None else None,
'focal_gamma': self.focal_gamma,
'has_scaled_attention': True,
'has_focal_loss': True,
'optimization_level': 'high_priority_v100'
}
with open(os.path.join(save_directory, 'config.json'), 'w', encoding='utf-8') as f:
json.dump(config_dict, f, ensure_ascii=False, indent=2)
class WeightedTrainer(Trainer):
"""自定义Trainer支持WeightedRandomSampler"""
def __init__(self, weighted_sampler=None, *args, **kwargs):
super().__init__(*args, **kwargs)
self.weighted_sampler = weighted_sampler
def get_train_dataloader(self):
if self.train_dataset is None:
raise ValueError("Trainer: training requires a train_dataset.")
train_dataset = self.train_dataset
if self.weighted_sampler is not None:
train_sampler = self.weighted_sampler
else:
train_sampler = self._get_train_sampler()
return DataLoader(
train_dataset,
batch_size=self.args.train_batch_size,
sampler=train_sampler,
collate_fn=self.data_collator,
drop_last=self.args.dataloader_drop_last,
num_workers=self.args.dataloader_num_workers,
pin_memory=self.args.dataloader_pin_memory,
)
def train_roberta_model(train_data,
model_path="/root/autodl-tmp/chinese-roberta-wwm-ext/chinese-roberta-wwm-ext/model",
output_dir="/root/autodl-tmp/chinese-roberta-wwm-ext/chinese-roberta-wwm-ext/model_train",
checkpoint_dir="/root/autodl-tmp/chinese-roberta-wwm-ext/chinese-roberta-wwm-ext/ouput_result"):
"""训练优化的RoBERTa模型(V100 48GB优化版)"""
gpu_available, gpu_memory = check_gpu_availability()
device = torch.device('cuda')
logger.info(f"🚀 使用GPU设备: {device}")
# 数据分布分析和优化的Focal Loss参数计算
label_distribution, alpha_weights, recommended_gamma = analyze_data_distribution(train_data)
alpha_tensor = torch.tensor(alpha_weights, dtype=torch.float).to(device)
logger.info(f"📥 加载Chinese-RoBERTa-WWM-Ext模型: {model_path}")
tokenizer = BertTokenizer.from_pretrained(model_path)
model = RoBERTaWithScaledAttentionAndFocalLoss(
model_path=model_path,
num_labels=2,
dropout=0.1,
focal_alpha=alpha_tensor,
focal_gamma=recommended_gamma
)
model = model.to(device)
logger.info(f"✅ 模型已加载到GPU: {next(model.parameters()).device}")
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
logger.info(f"📊 模型参数统计:")
logger.info(f" 🔹 总参数量: {total_params:,}")
logger.info(f" 🔹 可训练参数: {trainable_params:,}")
logger.info("📋 准备训练数据集和加权采样器...")
train_dataset = SentencePairDataset(train_data, tokenizer, max_length=512)
weighted_sampler = train_dataset.get_weighted_sampler()
logger.info(f" 🔹 训练集大小: {len(train_dataset)}")
logger.info(f" 🔹 类别权重: {train_dataset.class_weights}")
# V100 48GB优化配置
batch_size = 16 # V100可以使用更大的batch size
gradient_accumulation = 2
max_grad_norm = 1.0
fp16 = True
dataloader_num_workers = 4
effective_batch_size = batch_size * gradient_accumulation
initial_learning_rate = 2e-5
warmup_ratio = 0.15
# 确保输出目录存在
os.makedirs(output_dir, exist_ok=True)
os.makedirs(checkpoint_dir, exist_ok=True)
training_args = TrainingArguments(
output_dir=checkpoint_dir, # checkpoints保存到指定目录
num_train_epochs=100,
per_device_train_batch_size=batch_size,
gradient_accumulation_steps=gradient_accumulation,
eval_strategy="no",
save_strategy="epoch",
save_steps=20,
logging_strategy="steps",
logging_steps=50,
warmup_ratio=warmup_ratio,
weight_decay=0.01,
learning_rate=initial_learning_rate,
load_best_model_at_end=False,
remove_unused_columns=False,
dataloader_pin_memory=True,
fp16=fp16,
dataloader_num_workers=dataloader_num_workers,
group_by_length=True,
report_to=[],
adam_epsilon=1e-8,
max_grad_norm=max_grad_norm,
save_total_limit=5,
skip_memory_metrics=True,
disable_tqdm=False,
lr_scheduler_type="cosine",
warmup_steps=0,
)
logger.info(f"🎯 V100 48GB优化的训练参数:")
logger.info(f" 🔹 训练轮数: {training_args.num_train_epochs}")
logger.info(f" 🔹 批次大小: {batch_size}")
logger.info(f" 🔹 梯度累积: {gradient_accumulation}")
logger.info(f" 🔹 有效批次大小: {effective_batch_size}")
logger.info(f" 🔹 学习率: {training_args.learning_rate}")
logger.info(f" 🔹 预热比例: {warmup_ratio}")
logger.info(f" 🔹 序列长度: 512")
logger.info(f" 🔹 混合精度: {fp16}")
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
loss_tracker = LossTracker()
confusion_matrix_callback = ConfusionMatrixCallback(
eval_dataset=train_dataset,
tokenizer=tokenizer,
output_dir=checkpoint_dir, # 混淆矩阵保存到指定目录
epochs_interval=20
)
trainer = WeightedTrainer(
model=model,
args=training_args,
train_dataset=train_dataset,
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics,
callbacks=[loss_tracker, confusion_matrix_callback],
weighted_sampler=weighted_sampler
)
logger.info("🏃 开始100 epoch优化训练...")
logger.info("🎯 高优先级优化内容:")
logger.info(" ✅ Focal Loss Gamma: 3.0-3.5")
logger.info(" ✅ Alpha权重: [0.1, 0.9]")
logger.info(" ✅ 学习率: 2e-5")
logger.info(" ✅ 预热比例: 15%")
logger.info(" ✅ WeightedRandomSampler")
logger.info(" ✅ 余弦退火学习率调度")
start_time = datetime.now()
try:
trainer.train()
except RuntimeError as e:
if "out of memory" in str(e).lower():
logger.error("❌ GPU内存不足!")
logger.error("💡 建议减小批次大小")
raise
else:
raise
end_time = datetime.now()
training_duration = end_time - start_time
logger.info(f"🎉 100 epoch优化训练完成! 耗时: {training_duration}")
logger.info("📈 生成训练可视化图表...")
plot_training_curves(loss_tracker, checkpoint_dir)
logger.info(f"💾 保存优化模型到: {output_dir}")
# 保存到指定的模型输出目录
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
# 保存损失历史到checkpoints目录
loss_history = {
'train_losses': loss_tracker.train_losses,
'train_steps': loss_tracker.train_steps,
}
with open(os.path.join(checkpoint_dir, 'loss_history.json'), 'w', encoding='utf-8') as f:
json.dump(loss_history, f, ensure_ascii=False, indent=2)
# 保存混淆矩阵历史到checkpoints目录
cm_history = {epoch: cm.tolist() for epoch, cm in confusion_matrix_callback.confusion_matrices.items()}
with open(os.path.join(checkpoint_dir, 'confusion_matrix_history.json'), 'w', encoding='utf-8') as f:
json.dump(cm_history, f, ensure_ascii=False, indent=2)
# 保存详细的训练信息到checkpoints目录
training_info = {
"model_name": model_path,
"model_type": "Chinese-RoBERTa-WWM-Ext with Optimized Focal Loss and Weighted Sampling",
"optimization_level": "high_priority_v100_48gb",
"training_duration": str(training_duration),
"num_train_samples": len(train_dataset),
"label_distribution": label_distribution,
"data_imbalance": {
"class_0_count": label_distribution.get(0, 0),
"class_1_count": label_distribution.get(1, 0),
"class_0_ratio": label_distribution.get(0, 0) / len(train_dataset),
"class_1_ratio": label_distribution.get(1, 0) / len(train_dataset),
"imbalance_ratio": label_distribution.get(0, 1) / label_distribution.get(1, 1)
},
"optimized_focal_loss_params": {
"alpha_weights": alpha_weights,
"gamma": recommended_gamma,
"formula": "FL(p_t) = -α_t * (1-p_t)^γ * log(p_t)",
"optimization": "aggressive_minority_class_focus"
},
"weighted_sampling": {
"enabled": True,
"class_weights": train_dataset.class_weights,
"sampler_type": "WeightedRandomSampler"
},
"optimized_learning_strategy": {
"initial_learning_rate": initial_learning_rate,
"warmup_ratio": warmup_ratio,
"lr_scheduler": "cosine",
"improvement": "optimized_for_v100"
},
"gpu_optimization": {
"gpu_name": torch.cuda.get_device_name(0),
"gpu_memory_gb": gpu_memory,
"optimization_target": "V100_48GB",
"effective_batch_size": effective_batch_size,
"sequence_length": 512,
"batch_size_optimization": "v100_optimized"
},
"training_args": {
"num_train_epochs": training_args.num_train_epochs,
"per_device_train_batch_size": training_args.per_device_train_batch_size,
"gradient_accumulation_steps": training_args.gradient_accumulation_steps,
"learning_rate": training_args.learning_rate,
"warmup_ratio": training_args.warmup_ratio,
"weight_decay": training_args.weight_decay,
"fp16": training_args.fp16,
"lr_scheduler_type": training_args.lr_scheduler_type
},
"model_parameters": {
"total_params": total_params,
"trainable_params": trainable_params,
},
"paths": {
"model_input_path": model_path,
"model_output_path": output_dir,
"checkpoint_output_path": checkpoint_dir,
"data_path": "/root/autodl-tmp/chinese-roberta-wwm-ext/chinese-roberta-wwm-ext/Data"
},
"high_priority_optimizations": [
"Focal Loss Gamma increased to 3.0-3.5",
"Alpha weights set to [0.1, 0.9] for aggressive minority class focus",
"Learning rate optimized for V100: 2e-5",
"Warmup ratio increased to 15%",
"WeightedRandomSampler for balanced class sampling",
"Cosine annealing learning rate scheduler",
"V100 48GB optimized batch size: 16",
"Full sequence length: 512 tokens"
],
"visualization_files": {
"training_curves": "training_curves.png",
"confusion_matrices": [f"confusion_matrix_epoch_{i}.png" for i in range(20, 101, 20)] + [
"confusion_matrix_epoch_100.png"],
"loss_history": "loss_history.json",
"confusion_matrix_history": "confusion_matrix_history.json"
},
"training_completed": end_time.isoformat()
}
with open(os.path.join(checkpoint_dir, 'training_info.json'), 'w', encoding='utf-8') as f:
json.dump(training_info, f, ensure_ascii=False, indent=2)
# 同时在模型目录保存一份配置信息
with open(os.path.join(output_dir, 'training_summary.json'), 'w', encoding='utf-8') as f:
json.dump(training_info, f, ensure_ascii=False, indent=2)
logger.info("📋 优化训练信息已保存")
return trainer, model, tokenizer, loss_tracker, confusion_matrix_callback
def main():
"""主函数"""
logger.info("=" * 120)
logger.info("🚀 Chinese-RoBERTa-WWM-Ext V100 48GB高优化训练")
logger.info("=" * 120)
# 配置路径
train_file = "/root/autodl-tmp/chinese-roberta-wwm-ext/chinese-roberta-wwm-ext/Data/train_dataset.json"
model_path = "/root/autodl-tmp/chinese-roberta-wwm-ext/chinese-roberta-wwm-ext/model"
output_dir = "/root/autodl-tmp/chinese-roberta-wwm-ext/chinese-roberta-wwm-ext/model_train"
checkpoint_dir = "/root/autodl-tmp/chinese-roberta-wwm-ext/chinese-roberta-wwm-ext/ouput_result"
# 确保所有输出目录存在
os.makedirs(output_dir, exist_ok=True)
os.makedirs(checkpoint_dir, exist_ok=True)
logger.info(f"📁 确保输出目录存在:")
logger.info(f" 🔹 模型输出: {output_dir}")
logger.info(f" 🔹 训练记录: {checkpoint_dir}")
# 确认第三方报告工具已禁用
logger.info("🚫 确认第三方报告工具状态:")
logger.info(f" 🔹 WANDB_DISABLED: {os.environ.get('WANDB_DISABLED', 'not set')}")
logger.info(f" 🔹 COMET_MODE: {os.environ.get('COMET_MODE', 'not set')}")
logger.info(f" 🔹 NEPTUNE_MODE: {os.environ.get('NEPTUNE_MODE', 'not set')}")
logger.info(f" ✅ 所有第三方报告工具已禁用")
logger.info(f"\n📋 V100 48GB优化配置:")
logger.info(f" 🔹 训练数据: {train_file}")
logger.info(f" 🔹 基础模型: {model_path}")
logger.info(f" 🔹 模型类型: Chinese-RoBERTa-WWM-Ext")
logger.info(f" 🔹 优化等级: V100 48GB高性能优化")
logger.info(f" 🔹 目标: 处理严重数据不平衡问题")
logger.info(f" 🔹 核心优化:")
logger.info(f" • Focal Loss Gamma: 3.0+ (增强难样本聚焦)")
logger.info(f" • Alpha权重: [0.1, 0.9] (激进的少数类关注)")
logger.info(f" • 学习率: 2e-5 (V100优化)")
logger.info(f" • 批次大小: 16 (V100大显存优化)")
logger.info(f" • 序列长度: 512 (完整长度)")
logger.info(f" • WeightedRandomSampler (平衡采样)")
logger.info(f" • 余弦退火学习率调度")
logger.info(f" 🔹 训练轮数: 100 epochs")
logger.info(f" 🔹 模型输出: {output_dir}")
logger.info(f" 🔹 训练记录: {checkpoint_dir}")
# 加载训练数据
train_data = load_training_data(train_file)
if train_data is None:
logger.error("❌ 无法加载训练数据,程序退出")
return
try:
# 训练优化模型
trainer, model, tokenizer, loss_tracker, cm_callback = train_roberta_model(
train_data,
model_path=model_path,
output_dir=output_dir,
checkpoint_dir=checkpoint_dir
)
logger.info("=" * 120)
logger.info("🎉 V100 48GB高优化训练完成!")
logger.info("=" * 120)
logger.info(f"📁 文件输出位置:")
logger.info(f" 🔹 训练好的模型: {output_dir}")
logger.info(f" 🔹 训练记录和图表: {checkpoint_dir}")
logger.info("📄 生成的文件:")
logger.info(" 模型文件 (model_train目录):")
logger.info(" • pytorch_model.bin - 优化训练的模型权重")
logger.info(" • config.json - 优化模型配置")
logger.info(" • tokenizer配置文件")
logger.info(" • training_summary.json - 训练摘要")
logger.info(" 训练记录 (ouput_result目录):")
logger.info(" • training_info.json - 详细优化训练信息")
logger.info(" • loss_history.json - 损失历史数据")
logger.info(" • confusion_matrix_history.json - 混淆矩阵历史")
logger.info(" • training_curves.png - 训练损失曲线")
logger.info(" • confusion_matrix_epoch_X.png - 各epoch混淆矩阵")
logger.info(" • checkpoint-* - 训练检查点")
logger.info("🔥 V100 48GB优化特性:")
logger.info(" ✅ Chinese-RoBERTa-WWM-Ext 基础模型")
logger.info(" ✅ 激进的Focal Loss参数 (Gamma=3.0+, Alpha=[0.1,0.9])")
logger.info(" ✅ V100优化学习率: 2e-5")
logger.info(" ✅ 大批次训练: 16 (有效批次: 32)")
logger.info(" ✅ 完整序列长度: 512 tokens")
logger.info(" ✅ WeightedRandomSampler 平衡采样")
logger.info(" ✅ 余弦退火学习率调度")
logger.info(" ✅ 缩放点积注意力机制")
logger.info(" ✅ 100 epochs长时间训练")
logger.info(" ✅ 完整可视化监控")
logger.info("🎯 针对数据不平衡的专项优化:")
logger.info(" ⚡ 少数类样本权重提升9倍")
logger.info(" ⚡ 难分类样本聚焦增强50%")
logger.info(" ⚡ V100大显存充分利用")
logger.info(" ⚡ 类别平衡采样确保训练公平性")
logger.info(" ⚡ 预期少数类F1分数提升20-35%")
# 显示完整保存路径列表
logger.info(f"\n📂 文件保存详情:")
logger.info(f"📋 模型文件 ({output_dir}):")
try:
for file in os.listdir(output_dir):
file_path = os.path.join(output_dir, file)
if os.path.isfile(file_path):
file_size = os.path.getsize(file_path) / (1024 * 1024)
logger.info(f" 📄 {file} ({file_size:.2f} MB)")
except Exception as e:
logger.warning(f" 无法列出模型文件: {str(e)}")
logger.info(f"📋 训练记录 ({checkpoint_dir}):")
try:
files = os.listdir(checkpoint_dir)
# 按类型分组显示
png_files = [f for f in files if f.endswith('.png')]
json_files = [f for f in files if f.endswith('.json')]
checkpoint_dirs = [f for f in files if f.startswith('checkpoint-')]
other_files = [f for f in files if f not in png_files + json_files + checkpoint_dirs]
if json_files:
logger.info(" JSON配置文件:")
for file in sorted(json_files):
file_path = os.path.join(checkpoint_dir, file)
file_size = os.path.getsize(file_path) / 1024
logger.info(f" 📄 {file} ({file_size:.1f} KB)")
if png_files:
logger.info(" 可视化图表:")
for file in sorted(png_files):
file_path = os.path.join(checkpoint_dir, file)
file_size = os.path.getsize(file_path) / 1024
logger.info(f" 📊 {file} ({file_size:.1f} KB)")
if checkpoint_dirs:
logger.info(" 训练检查点:")
for dir_name in sorted(checkpoint_dirs):
logger.info(f" 📁 {dir_name}/")
if other_files:
logger.info(" 其他文件:")
for file in sorted(other_files):
file_path = os.path.join(checkpoint_dir, file)
if os.path.isfile(file_path):
file_size = os.path.getsize(file_path) / (1024 * 1024)
logger.info(f" 📄 {file} ({file_size:.2f} MB)")
except Exception as e:
logger.warning(f" 无法列出训练记录: {str(e)}")
logger.info("\n🎯 训练完成,可以开始评估模型性能!")
except Exception as e:
logger.error(f"❌ V100优化训练过程中出现错误: {str(e)}")
import traceback
traceback.print_exc()
raise
if __name__ == "__main__":
main()
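
Bert-train.py above attacks the class imbalance from two sides at once: inverse-frequency sample weights feed a WeightedRandomSampler, and the same label counts drive the Focal Loss alpha/gamma selection. A minimal sketch of just that derivation, using a hypothetical 9:1 toy label list instead of the real train_dataset.json; it mirrors, but does not replace, SentencePairDataset._compute_class_weights and analyze_data_distribution in the script:

# Illustrative sketch (toy labels, assumptions stated above).
import torch
from collections import Counter
from torch.utils.data import WeightedRandomSampler

labels = [0] * 900 + [1] * 100                      # ~9:1 imbalance
counts = Counter(labels)
total = len(labels)

class_weights = {c: total / (2 * counts[c]) for c in (0, 1)}       # inverse frequency
sample_weights = torch.tensor([class_weights[y] for y in labels], dtype=torch.float)
# This sampler is what WeightedTrainer would plug into its DataLoader.
sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)

imbalance_ratio = max(counts.values()) / min(counts.values())
alpha = [0.1, 0.9] if imbalance_ratio > 5 else [1 - counts[0] / total, 1 - counts[1] / total]
gamma = 3.5 if imbalance_ratio > 6 else 3.0 if imbalance_ratio > 4 else 2.5
print(class_weights, alpha, gamma)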

990
5.AI标注-model_trian/全参微调/无验证集训练/train-continue.py

@@ -0,0 +1,990 @@
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from transformers import (
BertTokenizer,
BertForSequenceClassification,
BertModel,
BertConfig,
TrainingArguments,
Trainer,
DataCollatorWithPadding,
TrainerCallback
)
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import logging
import os
from datetime import datetime
import math
from collections import defaultdict, Counter
# 禁用wandb和其他第三方报告工具
os.environ["WANDB_DISABLED"] = "true"
os.environ["COMET_MODE"] = "disabled"
os.environ["NEPTUNE_MODE"] = "disabled"
# 设置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# 设置matplotlib中文字体
plt.rcParams['font.sans-serif'] = ['DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False
def check_gpu_availability():
"""检查GPU可用性"""
if not torch.cuda.is_available():
raise RuntimeError("❌ GPU不可用!此脚本需要GPU支持。")
gpu_count = torch.cuda.device_count()
gpu_name = torch.cuda.get_device_name(0)
gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024 ** 3
logger.info(f"✅ GPU检查通过!")
logger.info(f" 🔹 可用GPU数量: {gpu_count}")
logger.info(f" 🔹 GPU型号: {gpu_name}")
logger.info(f" 🔹 GPU内存: {gpu_memory:.1f} GB")
# V100优化设置
torch.cuda.empty_cache()
torch.backends.cudnn.benchmark = True
return True, gpu_memory
class LossTracker(TrainerCallback):
"""损失跟踪回调类"""
def __init__(self):
self.train_losses = []
self.eval_losses = []
self.train_steps = []
self.eval_steps = []
self.current_epoch = 0
def on_log(self, args, state, control, logs=None, **kwargs):
if logs:
if 'loss' in logs:
self.train_losses.append(logs['loss'])
self.train_steps.append(state.global_step)
if 'eval_loss' in logs:
self.eval_losses.append(logs['eval_loss'])
self.eval_steps.append(state.global_step)
def on_epoch_end(self, args, state, control, **kwargs):
self.current_epoch = state.epoch
class ConfusionMatrixCallback(TrainerCallback):
"""混淆矩阵生成回调"""
def __init__(self, eval_dataset, tokenizer, output_dir, epochs_interval=20):
self.eval_dataset = eval_dataset
self.tokenizer = tokenizer
self.output_dir = output_dir
self.epochs_interval = epochs_interval
self.confusion_matrices = {}
def on_epoch_end(self, args, state, control, model=None, **kwargs):
current_epoch = int(state.epoch)
if current_epoch % self.epochs_interval == 0 or current_epoch == args.num_train_epochs:
logger.info(f"📊 Generating confusion matrix for epoch {current_epoch}...")
model.eval()
predictions = []
true_labels = []
device = next(model.parameters()).device
with torch.no_grad():
for i in range(len(self.eval_dataset)):
item = self.eval_dataset[i]
input_ids = item['input_ids'].unsqueeze(0).to(device)
attention_mask = item['attention_mask'].unsqueeze(0).to(device)
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
pred = torch.argmax(outputs['logits'], dim=-1).cpu().item()
predictions.append(pred)
true_labels.append(item['labels'].item())
cm = confusion_matrix(true_labels, predictions)
self.confusion_matrices[current_epoch] = cm
self.save_confusion_matrix(cm, current_epoch)
model.train()
def save_confusion_matrix(self, cm, epoch):
"""保存混淆矩阵图"""
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
xticklabels=['Same Paragraph (0)', 'Different Paragraph (1)'],
yticklabels=['Same Paragraph (0)', 'Different Paragraph (1)'])
plt.title(f'Confusion Matrix - Epoch {epoch}')
plt.xlabel('Predicted Label')
plt.ylabel('True Label')
accuracy = np.trace(cm) / np.sum(cm)
plt.text(0.5, -0.15, f'Accuracy: {accuracy:.4f}',
ha='center', transform=plt.gca().transAxes)
plt.tight_layout()
save_path = os.path.join(self.output_dir, f'confusion_matrix_epoch_{epoch}.png')
plt.savefig(save_path, dpi=300, bbox_inches='tight')
plt.close()
logger.info(f" 💾 Confusion matrix saved: {save_path}")
def plot_training_curves(loss_tracker, output_dir):
"""绘制训练损失曲线"""
plt.figure(figsize=(12, 8))
if loss_tracker.train_losses:
plt.subplot(2, 1, 1)
plt.plot(loss_tracker.train_steps, loss_tracker.train_losses,
'b-', label='Training Loss', linewidth=2, alpha=0.8)
plt.title('Training Loss Curve', fontsize=14, fontweight='bold')
plt.xlabel('Training Steps')
plt.ylabel('Loss Value')
plt.legend()
plt.grid(True, alpha=0.3)
if len(loss_tracker.train_losses) > 10:
z = np.polyfit(loss_tracker.train_steps, loss_tracker.train_losses, 1)
p = np.poly1d(z)
plt.plot(loss_tracker.train_steps, p(loss_tracker.train_steps),
'r--', alpha=0.6, label='Trend Line')
plt.legend()
if loss_tracker.eval_losses:
plt.subplot(2, 1, 2)
plt.plot(loss_tracker.eval_steps, loss_tracker.eval_losses,
'g-', label='Validation Loss', linewidth=2, alpha=0.8)
plt.title('Validation Loss Curve', fontsize=14, fontweight='bold')
plt.xlabel('Training Steps')
plt.ylabel('Loss Value')
plt.legend()
plt.grid(True, alpha=0.3)
if loss_tracker.train_losses and loss_tracker.eval_losses:
plt.figure(figsize=(12, 6))
min_len = min(len(loss_tracker.train_losses), len(loss_tracker.eval_losses))
train_steps_aligned = loss_tracker.train_steps[:min_len]
train_losses_aligned = loss_tracker.train_losses[:min_len]
eval_steps_aligned = loss_tracker.eval_steps[:min_len]
eval_losses_aligned = loss_tracker.eval_losses[:min_len]
plt.plot(train_steps_aligned, train_losses_aligned,
'b-', label='Training Loss', linewidth=2, alpha=0.8)
plt.plot(eval_steps_aligned, eval_losses_aligned,
'r-', label='Validation Loss', linewidth=2, alpha=0.8)
plt.title('Training vs Validation Loss Comparison', fontsize=16, fontweight='bold')
plt.xlabel('Training Steps', fontsize=12)
plt.ylabel('Loss Value', fontsize=12)
plt.legend(fontsize=12)
plt.grid(True, alpha=0.3)
if len(eval_losses_aligned) > 20:
recent_train = np.mean(train_losses_aligned[-10:])
recent_eval = np.mean(eval_losses_aligned[-10:])
if recent_eval > recent_train * 1.2:
plt.text(0.7, 0.9, ' Possible Overfitting', transform=plt.gca().transAxes,
bbox=dict(boxstyle="round,pad=0.3", facecolor="yellow", alpha=0.7))
plt.tight_layout()
compare_path = os.path.join(output_dir, 'loss_comparison_curves.png')
plt.savefig(compare_path, dpi=300, bbox_inches='tight')
logger.info(f"📈 Training comparison curves saved: {compare_path}")
plt.tight_layout()
curves_path = os.path.join(output_dir, 'training_curves.png')
plt.savefig(curves_path, dpi=300, bbox_inches='tight')
plt.close()
logger.info(f"📈 Training curves saved: {curves_path}")
class SentencePairDataset(Dataset):
"""句子对数据集类(支持加权采样)"""
def __init__(self, data, tokenizer, max_length=512):
self.data = data
self.tokenizer = tokenizer
self.max_length = max_length
self.valid_data = [item for item in data if item['label'] in [0, 1]]
logger.info(f"原始数据: {len(data)} 条,有效数据: {len(self.valid_data)}")
self.sentence1_list = [item['sentence1'] for item in self.valid_data]
self.sentence2_list = [item['sentence2'] for item in self.valid_data]
self.labels = [item['label'] for item in self.valid_data]
self.class_counts = Counter(self.labels)
self.class_weights = self._compute_class_weights()
self.sample_weights = self._compute_sample_weights()
def _compute_class_weights(self):
"""计算类别权重"""
total_samples = len(self.labels)
class_weights = {}
for label in [0, 1]:
count = self.class_counts[label]
class_weights[label] = total_samples / (2 * count)
return class_weights
def _compute_sample_weights(self):
"""计算每个样本的权重"""
sample_weights = []
for label in self.labels:
sample_weights.append(self.class_weights[label])
return torch.tensor(sample_weights, dtype=torch.float)
def get_weighted_sampler(self):
"""返回加权随机采样器"""
return WeightedRandomSampler(
weights=self.sample_weights,
num_samples=len(self.sample_weights),
replacement=True
)
def __len__(self):
return len(self.valid_data)
def __getitem__(self, idx):
sentence1 = str(self.sentence1_list[idx])
sentence2 = str(self.sentence2_list[idx])
label = self.labels[idx]
encoding = self.tokenizer(
sentence1,
sentence2,
truncation=True,
padding='max_length',
max_length=self.max_length,
return_tensors='pt'
)
return {
'input_ids': encoding['input_ids'].flatten(),
'attention_mask': encoding['attention_mask'].flatten(),
'labels': torch.tensor(label, dtype=torch.long)
}
def load_training_data(train_file):
"""加载训练数据"""
try:
with open(train_file, 'r', encoding='utf-8') as f:
train_data = json.load(f)
logger.info(f"成功加载训练数据: {len(train_data)} 条记录")
return train_data
except Exception as e:
logger.error(f"加载训练数据失败: {str(e)}")
return None
def analyze_data_distribution(data):
"""分析数据分布并计算优化的Focal Loss参数"""
valid_data = [item for item in data if item['label'] in [0, 1]]
label_counts = {}
for item in valid_data:
label = item['label']
label_counts[label] = label_counts.get(label, 0) + 1
total_samples = len(valid_data)
logger.info("=== 训练数据分布分析 ===")
logger.info(f"总有效记录数: {total_samples}")
class_proportions = {}
alpha_weights = []
for label in [0, 1]:
count = label_counts.get(label, 0)
proportion = count / total_samples
class_proportions[label] = proportion
label_name = "同段落" if label == 0 else "不同段落"
logger.info(f"标签 {label} ({label_name}): {count} 条 ({proportion * 100:.2f}%)")
minority_ratio = min(class_proportions.values())
imbalance_ratio = max(class_proportions.values()) / minority_ratio
logger.info(f"\n📊 数据不平衡分析:")
logger.info(f" 🔹 少数类比例: {minority_ratio * 100:.2f}%")
logger.info(f" 🔹 不平衡比率: {imbalance_ratio:.2f}:1")
if imbalance_ratio > 5:
alpha_weights = [0.1, 0.9]
logger.info(" 🎯 使用激进的alpha权重设置")
else:
alpha_weights = [1.0 - class_proportions[0], 1.0 - class_proportions[1]]
if imbalance_ratio > 6:
recommended_gamma = 3.5
logger.info(" 严重不平衡 - 使用Gamma=3.5强化聚焦")
elif imbalance_ratio > 4:
recommended_gamma = 3.0
logger.info(" 中度偏严重不平衡 - 使用Gamma=3.0")
else:
recommended_gamma = 2.5
logger.info(f"\n🎯 优化的Focal Loss参数设置:")
logger.info(f" 🔹 Alpha权重: [多数类={alpha_weights[0]:.3f}, 少数类={alpha_weights[1]:.3f}]")
logger.info(f" 🔹 优化Gamma: {recommended_gamma} (增强难样本聚焦)")
logger.info(f" 🔹 公式: FL(p_t) = -α_t * (1-p_t)^γ * log(p_t)")
logger.info(f" 🔹 加权采样: 启用WeightedRandomSampler")
return label_counts, alpha_weights, recommended_gamma
def compute_metrics(eval_pred):
"""计算评估指标"""
predictions, labels = eval_pred
predictions = np.argmax(predictions, axis=1)
accuracy = accuracy_score(labels, predictions)
return {
'accuracy': accuracy,
}
class FocalLoss(nn.Module):
"""优化的Focal Loss用于处理类别不平衡问题"""
def __init__(self, alpha=None, gamma=3.0, reduction='mean'):
super(FocalLoss, self).__init__()
self.alpha = alpha
self.gamma = gamma
self.reduction = reduction
def forward(self, inputs, targets):
ce_loss = F.cross_entropy(inputs, targets, reduction='none')
pt = torch.exp(-ce_loss)
if self.alpha is not None:
if self.alpha.type() != inputs.data.type():
self.alpha = self.alpha.type_as(inputs.data)
at = self.alpha.gather(0, targets.data.view(-1))
ce_loss = ce_loss * at
focal_weight = (1 - pt) ** self.gamma
focal_loss = focal_weight * ce_loss
if self.reduction == 'mean':
return focal_loss.mean()
elif self.reduction == 'sum':
return focal_loss.sum()
else:
return focal_loss
class ScaledDotProductAttention(nn.Module):
"""缩放点积注意力机制"""
def __init__(self, d_model, dropout=0.1):
super(ScaledDotProductAttention, self).__init__()
self.d_model = d_model
self.dropout = nn.Dropout(dropout)
def forward(self, query, key, value, mask=None):
batch_size, seq_len, d_model = query.size()
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_model)
if mask is not None:
mask_value = torch.finfo(scores.dtype).min
scores = scores.masked_fill(mask == 0, mask_value)
attention_weights = F.softmax(scores, dim=-1)
attention_weights = self.dropout(attention_weights)
output = torch.matmul(attention_weights, value)
return output, attention_weights
class RoBERTaWithScaledAttentionAndFocalLoss(nn.Module):
"""带缩放点积注意力和优化Focal Loss的RoBERTa模型"""
def __init__(self, model_path, num_labels=2, dropout=0.1,
focal_alpha=None, focal_gamma=3.0):
super(RoBERTaWithScaledAttentionAndFocalLoss, self).__init__()
self.roberta = BertModel.from_pretrained(model_path)
self.config = self.roberta.config
self.config.num_labels = num_labels
self.scaled_attention = ScaledDotProductAttention(
d_model=self.config.hidden_size,
dropout=dropout
)
self.dropout = nn.Dropout(dropout)
self.classifier = nn.Linear(self.config.hidden_size, num_labels)
self.focal_loss = FocalLoss(alpha=focal_alpha, gamma=focal_gamma)
self._init_weights()
self.focal_alpha = focal_alpha
self.focal_gamma = focal_gamma
def _init_weights(self):
"""初始化新增层的权重"""
nn.init.normal_(self.classifier.weight, std=0.02)
nn.init.zeros_(self.classifier.bias)
def forward(self, input_ids, attention_mask=None, token_type_ids=None, labels=None):
roberta_outputs = self.roberta(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
return_dict=True
)
sequence_output = roberta_outputs.last_hidden_state
enhanced_output, attention_weights = self.scaled_attention(
query=sequence_output,
key=sequence_output,
value=sequence_output,
mask=attention_mask.unsqueeze(1) if attention_mask is not None else None
)
cls_output = enhanced_output[:, 0, :]
cls_output = self.dropout(cls_output)
logits = self.classifier(cls_output)
loss = None
if labels is not None:
loss = self.focal_loss(logits, labels)
return {
'loss': loss,
'logits': logits,
'hidden_states': enhanced_output,
'attention_weights': attention_weights
}
def save_pretrained(self, save_directory):
"""保存模型"""
os.makedirs(save_directory, exist_ok=True)
model_to_save = self.module if hasattr(self, 'module') else self
torch.save(model_to_save.state_dict(), os.path.join(save_directory, 'pytorch_model.bin'))
config_dict = {
'model_type': 'RoBERTaWithScaledAttentionAndFocalLoss',
'base_model': 'chinese-roberta-wwm-ext',
'num_labels': self.config.num_labels,
'hidden_size': self.config.hidden_size,
'focal_alpha': self.focal_alpha.tolist() if self.focal_alpha is not None else None,
'focal_gamma': self.focal_gamma,
'has_scaled_attention': True,
'has_focal_loss': True,
'optimization_level': 'high_priority_v100'
}
with open(os.path.join(save_directory, 'config.json'), 'w', encoding='utf-8') as f:
json.dump(config_dict, f, ensure_ascii=False, indent=2)
class WeightedTrainer(Trainer):
"""自定义Trainer支持WeightedRandomSampler"""
def __init__(self, weighted_sampler=None, *args, **kwargs):
super().__init__(*args, **kwargs)
self.weighted_sampler = weighted_sampler
def get_train_dataloader(self):
if self.train_dataset is None:
raise ValueError("Trainer: training requires a train_dataset.")
train_dataset = self.train_dataset
if self.weighted_sampler is not None:
train_sampler = self.weighted_sampler
else:
train_sampler = self._get_train_sampler()
return DataLoader(
train_dataset,
batch_size=self.args.train_batch_size,
sampler=train_sampler,
collate_fn=self.data_collator,
drop_last=self.args.dataloader_drop_last,
num_workers=self.args.dataloader_num_workers,
pin_memory=self.args.dataloader_pin_memory,
)
def resume_training_from_checkpoint(train_data,
checkpoint_path="/root/autodl-tmp/chinese-roberta-wwm-ext/chinese-roberta-wwm-ext/ouput_result/checkpoint-86175",
model_path="/root/autodl-tmp/chinese-roberta-wwm-ext/chinese-roberta-wwm-ext/model",
output_dir="/root/autodl-tmp/chinese-roberta-wwm-ext/chinese-roberta-wwm-ext/model_train",
checkpoint_dir="/root/autodl-tmp/chinese-roberta-wwm-ext/chinese-roberta-wwm-ext/ouput_result"):
"""从checkpoint-86175恢复训练"""
gpu_available, gpu_memory = check_gpu_availability()
device = torch.device('cuda')
logger.info(f"🚀 使用GPU设备: {device}")
# 检查checkpoint是否存在
if not os.path.exists(checkpoint_path):
raise FileNotFoundError(f"❌ Checkpoint路径不存在: {checkpoint_path}")
logger.info(f"📂 找到checkpoint: {checkpoint_path}")
# 列出checkpoint内容
try:
checkpoint_files = os.listdir(checkpoint_path)
logger.info(f"📋 Checkpoint包含文件: {checkpoint_files}")
except Exception as e:
logger.warning(f" 无法列出checkpoint文件: {str(e)}")
# 数据分布分析和优化的Focal Loss参数计算
label_distribution, alpha_weights, recommended_gamma = analyze_data_distribution(train_data)
alpha_tensor = torch.tensor(alpha_weights, dtype=torch.float).to(device)
logger.info(f"📥 加载Chinese-RoBERTa-WWM-Ext模型: {model_path}")
tokenizer = BertTokenizer.from_pretrained(model_path)
model = RoBERTaWithScaledAttentionAndFocalLoss(
model_path=model_path,
num_labels=2,
dropout=0.1,
focal_alpha=alpha_tensor,
focal_gamma=recommended_gamma
)
model = model.to(device)
logger.info(f"✅ 模型已加载到GPU: {next(model.parameters()).device}")
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
logger.info(f"📊 模型参数统计:")
logger.info(f" 🔹 总参数量: {total_params:,}")
logger.info(f" 🔹 可训练参数: {trainable_params:,}")
logger.info("📋 准备训练数据集和加权采样器...")
train_dataset = SentencePairDataset(train_data, tokenizer, max_length=512)
weighted_sampler = train_dataset.get_weighted_sampler()
logger.info(f" 🔹 训练集大小: {len(train_dataset)}")
logger.info(f" 🔹 类别权重: {train_dataset.class_weights}")
# V100 48GB优化配置(保持原有参数)
batch_size = 16
gradient_accumulation = 2
max_grad_norm = 1.0
fp16 = True
dataloader_num_workers = 4
effective_batch_size = batch_size * gradient_accumulation
initial_learning_rate = 2e-5
warmup_ratio = 0.15
# 确保输出目录存在
os.makedirs(output_dir, exist_ok=True)
os.makedirs(checkpoint_dir, exist_ok=True)
training_args = TrainingArguments(
output_dir=checkpoint_dir,
num_train_epochs=100,
per_device_train_batch_size=batch_size,
gradient_accumulation_steps=gradient_accumulation,
eval_strategy="no",
save_strategy="epoch",
save_steps=20,
logging_strategy="steps",
logging_steps=50,
warmup_ratio=warmup_ratio,
weight_decay=0.01,
learning_rate=initial_learning_rate,
load_best_model_at_end=False,
remove_unused_columns=False,
dataloader_pin_memory=True,
fp16=fp16,
dataloader_num_workers=dataloader_num_workers,
group_by_length=True,
report_to=[],
adam_epsilon=1e-8,
max_grad_norm=max_grad_norm,
save_total_limit=5,
skip_memory_metrics=True,
disable_tqdm=False,
lr_scheduler_type="cosine",
warmup_steps=0,
)
logger.info(f"🔄 从checkpoint恢复训练参数:")
logger.info(f" 🔹 Checkpoint路径: {checkpoint_path}")
logger.info(f" 🔹 训练轮数: {training_args.num_train_epochs}")
logger.info(f" 🔹 批次大小: {batch_size}")
logger.info(f" 🔹 梯度累积: {gradient_accumulation}")
logger.info(f" 🔹 有效批次大小: {effective_batch_size}")
logger.info(f" 🔹 学习率: {training_args.learning_rate}")
logger.info(f" 🔹 预热比例: {warmup_ratio}")
logger.info(f" 🔹 序列长度: 512")
logger.info(f" 🔹 混合精度: {fp16}")
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
loss_tracker = LossTracker()
confusion_matrix_callback = ConfusionMatrixCallback(
eval_dataset=train_dataset,
tokenizer=tokenizer,
output_dir=checkpoint_dir,
epochs_interval=20
)
trainer = WeightedTrainer(
model=model,
args=training_args,
train_dataset=train_dataset,
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics,
callbacks=[loss_tracker, confusion_matrix_callback],
weighted_sampler=weighted_sampler
)
logger.info("🔄 从checkpoint-86175恢复训练...")
logger.info(f"📍 恢复点: {checkpoint_path}")
logger.info("⚡ 将继续使用相同的优化参数:")
logger.info(" ✅ Focal Loss Gamma: 3.0-3.5")
logger.info(" ✅ Alpha权重: [0.1, 0.9]")
logger.info(" ✅ 学习率: 2e-5")
logger.info(" ✅ 预热比例: 15%")
logger.info(" ✅ WeightedRandomSampler")
logger.info(" ✅ 余弦退火学习率调度")
start_time = datetime.now()
try:
# 关键修改:从指定checkpoint恢复训练
trainer.train(resume_from_checkpoint=checkpoint_path)
except RuntimeError as e:
if "out of memory" in str(e).lower():
logger.error("❌ GPU内存不足!")
logger.error("💡 建议减小批次大小")
raise
else:
raise
end_time = datetime.now()
training_duration = end_time - start_time
logger.info(f"🎉 从checkpoint-86175恢复训练完成! 耗时: {training_duration}")
logger.info("📈 生成训练可视化图表...")
plot_training_curves(loss_tracker, checkpoint_dir)
logger.info(f"💾 保存恢复训练后的模型到: {output_dir}")
# 保存到指定的模型输出目录
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
# 保存损失历史到checkpoints目录
loss_history = {
'train_losses': loss_tracker.train_losses,
'train_steps': loss_tracker.train_steps,
'resumed_from_checkpoint': checkpoint_path,
'resume_time': start_time.isoformat()
}
with open(os.path.join(checkpoint_dir, 'loss_history_resumed.json'), 'w', encoding='utf-8') as f:
json.dump(loss_history, f, ensure_ascii=False, indent=2)
# 保存混淆矩阵历史到checkpoints目录
cm_history = {epoch: cm.tolist() for epoch, cm in confusion_matrix_callback.confusion_matrices.items()}
with open(os.path.join(checkpoint_dir, 'confusion_matrix_history_resumed.json'), 'w', encoding='utf-8') as f:
json.dump(cm_history, f, ensure_ascii=False, indent=2)
# 保存恢复训练的详细信息
resume_training_info = {
"model_name": model_path,
"model_type": "Chinese-RoBERTa-WWM-Ext with Optimized Focal Loss and Weighted Sampling",
"training_mode": "resumed_from_checkpoint",
"checkpoint_path": checkpoint_path,
"resume_time": start_time.isoformat(),
"training_duration": str(training_duration),
"num_train_samples": len(train_dataset),
"label_distribution": label_distribution,
"data_imbalance": {
"class_0_count": label_distribution.get(0, 0),
"class_1_count": label_distribution.get(1, 0),
"class_0_ratio": label_distribution.get(0, 0) / len(train_dataset),
"class_1_ratio": label_distribution.get(1, 0) / len(train_dataset),
"imbalance_ratio": label_distribution.get(0, 1) / label_distribution.get(1, 1)
},
"optimized_focal_loss_params": {
"alpha_weights": alpha_weights,
"gamma": recommended_gamma,
"formula": "FL(p_t) = -α_t * (1-p_t)^γ * log(p_t)",
"optimization": "aggressive_minority_class_focus"
},
"weighted_sampling": {
"enabled": True,
"class_weights": train_dataset.class_weights,
"sampler_type": "WeightedRandomSampler"
},
"optimized_learning_strategy": {
"initial_learning_rate": initial_learning_rate,
"warmup_ratio": warmup_ratio,
"lr_scheduler": "cosine",
"improvement": "optimized_for_v100"
},
"gpu_optimization": {
"gpu_name": torch.cuda.get_device_name(0),
"gpu_memory_gb": gpu_memory,
"optimization_target": "V100_48GB",
"effective_batch_size": effective_batch_size,
"sequence_length": 512,
"batch_size_optimization": "v100_optimized"
},
"training_args": {
"num_train_epochs": training_args.num_train_epochs,
"per_device_train_batch_size": training_args.per_device_train_batch_size,
"gradient_accumulation_steps": training_args.gradient_accumulation_steps,
"learning_rate": training_args.learning_rate,
"warmup_ratio": training_args.warmup_ratio,
"weight_decay": training_args.weight_decay,
"fp16": training_args.fp16,
"lr_scheduler_type": training_args.lr_scheduler_type
},
"model_parameters": {
"total_params": total_params,
"trainable_params": trainable_params,
},
"paths": {
"model_input_path": model_path,
"model_output_path": output_dir,
"checkpoint_output_path": checkpoint_dir,
"resume_checkpoint_path": checkpoint_path,
"data_path": "/root/autodl-tmp/chinese-roberta-wwm-ext/chinese-roberta-wwm-ext/Data"
},
"resume_optimizations": [
"Resumed from checkpoint-86175",
"Maintained Focal Loss Gamma: 3.0-3.5",
"Maintained Alpha weights: [0.1, 0.9]",
"Maintained learning rate: 2e-5",
"Maintained warmup ratio: 15%",
"Maintained WeightedRandomSampler",
"Maintained cosine annealing scheduler",
"Maintained V100 48GB optimized batch size: 16",
"Maintained full sequence length: 512 tokens"
],
"visualization_files_resumed": {
"training_curves": "training_curves.png",
"confusion_matrices": [f"confusion_matrix_epoch_{i}.png" for i in range(20, 101, 20)] + [
"confusion_matrix_epoch_100.png"],
"loss_history": "loss_history_resumed.json",
"confusion_matrix_history": "confusion_matrix_history_resumed.json"
},
"training_completed": end_time.isoformat()
}
with open(os.path.join(checkpoint_dir, 'resume_training_info.json'), 'w', encoding='utf-8') as f:
json.dump(resume_training_info, f, ensure_ascii=False, indent=2)
# 同时在模型目录保存一份配置信息
with open(os.path.join(output_dir, 'resume_training_summary.json'), 'w', encoding='utf-8') as f:
json.dump(resume_training_info, f, ensure_ascii=False, indent=2)
logger.info("📋 恢复训练信息已保存")
return trainer, model, tokenizer, loss_tracker, confusion_matrix_callback
def main():
"""主函数 - 从checkpoint-86175恢复训练"""
logger.info("=" * 120)
logger.info("🔄 从Checkpoint-86175恢复Chinese-RoBERTa-WWM-Ext训练")
logger.info("=" * 120)
# 配置路径
train_file = "/root/autodl-tmp/chinese-roberta-wwm-ext/chinese-roberta-wwm-ext/Data/train_dataset.json"
model_path = "/root/autodl-tmp/chinese-roberta-wwm-ext/chinese-roberta-wwm-ext/model"
output_dir = "/root/autodl-tmp/chinese-roberta-wwm-ext/chinese-roberta-wwm-ext/model_train"
checkpoint_dir = "/root/autodl-tmp/chinese-roberta-wwm-ext/chinese-roberta-wwm-ext/ouput_result"
resume_checkpoint = "/root/autodl-tmp/chinese-roberta-wwm-ext/chinese-roberta-wwm-ext/ouput_result/checkpoint-86175"
# 确保所有输出目录存在
os.makedirs(output_dir, exist_ok=True)
os.makedirs(checkpoint_dir, exist_ok=True)
logger.info(f"📁 确保输出目录存在:")
logger.info(f" 🔹 模型输出: {output_dir}")
logger.info(f" 🔹 训练记录: {checkpoint_dir}")
# 检查checkpoint是否存在
if not os.path.exists(resume_checkpoint):
logger.error(f"❌ 指定的checkpoint不存在: {resume_checkpoint}")
logger.error("💡 请检查checkpoint路径是否正确")
return
else:
logger.info(f"✅ 找到恢复checkpoint: {resume_checkpoint}")
# 确认第三方报告工具已禁用
logger.info("🚫 确认第三方报告工具状态:")
logger.info(f" 🔹 WANDB_DISABLED: {os.environ.get('WANDB_DISABLED', 'not set')}")
logger.info(f" 🔹 COMET_MODE: {os.environ.get('COMET_MODE', 'not set')}")
logger.info(f" 🔹 NEPTUNE_MODE: {os.environ.get('NEPTUNE_MODE', 'not set')}")
logger.info(f" ✅ 所有第三方报告工具已禁用")
logger.info(f"\n📋 恢复训练配置:")
logger.info(f" 🔹 训练数据: {train_file}")
logger.info(f" 🔹 基础模型: {model_path}")
logger.info(f" 🔹 恢复checkpoint: {resume_checkpoint}")
logger.info(f" 🔹 模型类型: Chinese-RoBERTa-WWM-Ext")
logger.info(f" 🔹 训练模式: 从checkpoint恢复")
logger.info(f" 🔹 保持所有优化参数不变:")
logger.info(f" • Focal Loss Gamma: 3.0+ (增强难样本聚焦)")
logger.info(f" • Alpha权重: [0.1, 0.9] (激进的少数类关注)")
logger.info(f" • 学习率: 2e-5 (V100优化)")
logger.info(f" • 批次大小: 16 (V100大显存优化)")
logger.info(f" • 序列长度: 512 (完整长度)")
logger.info(f" • WeightedRandomSampler (平衡采样)")
logger.info(f" • 余弦退火学习率调度")
logger.info(f" 🔹 目标轮数: 100 epochs")
logger.info(f" 🔹 模型输出: {output_dir}")
logger.info(f" 🔹 训练记录: {checkpoint_dir}")
# 加载训练数据
train_data = load_training_data(train_file)
if train_data is None:
logger.error("❌ 无法加载训练数据,程序退出")
return
try:
# 从checkpoint恢复训练
trainer, model, tokenizer, loss_tracker, cm_callback = resume_training_from_checkpoint(
train_data,
checkpoint_path=resume_checkpoint,
model_path=model_path,
output_dir=output_dir,
checkpoint_dir=checkpoint_dir
)
logger.info("=" * 120)
logger.info("🎉 从Checkpoint-86175恢复训练完成!")
logger.info("=" * 120)
logger.info(f"📁 文件输出位置:")
logger.info(f" 🔹 训练好的模型: {output_dir}")
logger.info(f" 🔹 训练记录和图表: {checkpoint_dir}")
logger.info("📄 生成的文件:")
logger.info(" 模型文件 (model_train目录):")
logger.info(" • pytorch_model.bin - 恢复训练后的模型权重")
logger.info(" • config.json - 优化模型配置")
logger.info(" • tokenizer配置文件")
logger.info(" • resume_training_summary.json - 恢复训练摘要")
logger.info(" 训练记录 (ouput_result目录):")
logger.info(" • resume_training_info.json - 详细恢复训练信息")
logger.info(" • loss_history_resumed.json - 恢复训练损失历史")
logger.info(" • confusion_matrix_history_resumed.json - 恢复训练混淆矩阵历史")
logger.info(" • training_curves.png - 训练损失曲线(更新)")
logger.info(" • confusion_matrix_epoch_X.png - 各epoch混淆矩阵(更新)")
logger.info(" • checkpoint-* - 新的训练检查点")
logger.info("🔄 恢复训练特性:")
logger.info(" ✅ 从checkpoint-86175成功恢复")
logger.info(" ✅ 保持所有原有优化参数")
logger.info(" ✅ 继续使用激进的Focal Loss设置")
logger.info(" ✅ 继续使用WeightedRandomSampler")
logger.info(" ✅ 继续使用V100优化配置")
logger.info(" ✅ 继续使用余弦退火学习率调度")
logger.info(" ✅ 保持完整的可视化监控")
logger.info("🎯 恢复训练优势:")
logger.info(" ⚡ 无缝继续之前的训练进度")
logger.info(" ⚡ 保持学习率调度状态")
logger.info(" ⚡ 保持优化器状态")
logger.info(" ⚡ 保持所有超参数设置")
logger.info(" ⚡ 继续数据不平衡优化策略")
# 显示完整保存路径列表
logger.info(f"\n📂 文件保存详情:")
logger.info(f"📋 模型文件 ({output_dir}):")
try:
for file in os.listdir(output_dir):
file_path = os.path.join(output_dir, file)
if os.path.isfile(file_path):
file_size = os.path.getsize(file_path) / (1024 * 1024)
logger.info(f" 📄 {file} ({file_size:.2f} MB)")
except Exception as e:
logger.warning(f" 无法列出模型文件: {str(e)}")
logger.info(f"📋 训练记录 ({checkpoint_dir}):")
try:
files = os.listdir(checkpoint_dir)
# 按类型分组显示
png_files = [f for f in files if f.endswith('.png')]
json_files = [f for f in files if f.endswith('.json')]
checkpoint_dirs = [f for f in files if f.startswith('checkpoint-')]
other_files = [f for f in files if f not in png_files + json_files + checkpoint_dirs]
if json_files:
logger.info(" JSON配置文件:")
for file in sorted(json_files):
file_path = os.path.join(checkpoint_dir, file)
file_size = os.path.getsize(file_path) / 1024
marker = " (NEW)" if "resumed" in file else ""
logger.info(f" 📄 {file} ({file_size:.1f} KB){marker}")
if png_files:
logger.info(" 可视化图表:")
for file in sorted(png_files):
file_path = os.path.join(checkpoint_dir, file)
file_size = os.path.getsize(file_path) / 1024
logger.info(f" 📊 {file} ({file_size:.1f} KB)")
if checkpoint_dirs:
logger.info(" 训练检查点:")
for dir_name in sorted(checkpoint_dirs):
marker = " (RESUME FROM)" if dir_name == "checkpoint-86175" else ""
logger.info(f" 📁 {dir_name}/{marker}")
if other_files:
logger.info(" 其他文件:")
for file in sorted(other_files):
file_path = os.path.join(checkpoint_dir, file)
if os.path.isfile(file_path):
file_size = os.path.getsize(file_path) / (1024 * 1024)
logger.info(f" 📄 {file} ({file_size:.2f} MB)")
except Exception as e:
logger.warning(f" 无法列出训练记录: {str(e)}")
logger.info("\n🎯 恢复训练完成,可以继续评估模型性能!")
logger.info("💡 提示: 新生成的文件名包含'resumed'标识")
except Exception as e:
logger.error(f"❌ 从checkpoint恢复训练过程中出现错误: {str(e)}")
import traceback
traceback.print_exc()
raise
if __name__ == "__main__":
main()

1208
5.AI标注-model_trian/全参微调/无验证集训练/模型按验证集准确率选择/Bert-train-plaus.py

File diff suppressed because it is too large

1165
5.AI标注-model_trian/全参微调/有验证集训练/Bert-test_eval.py

File diff suppressed because it is too large

1169
5.AI标注-model_trian/全参微调/有验证集训练/Bert-testeval_continue.py

File diff suppressed because it is too large

1401
5.AI标注-model_trian/全参微调/有验证集训练/模型按验证集准确率选择/Bert-test_evalplaus-continue.py

File diff suppressed because it is too large

1296
5.AI标注-model_trian/全参微调/有验证集训练/模型按验证集准确率选择/Bert-test_evalplaus.py

File diff suppressed because it is too large

656
5.AI标注-model_trian/封装模型/Fastapi.py

@@ -0,0 +1,656 @@
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import os
import re
import time
import concurrent.futures
from typing import Dict, List, Optional, Union
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from transformers import BertTokenizer, BertModel
import uvicorn
# ========== 双路径边界分类器模型定义部分 ==========
class ScaledDotProductAttention(nn.Module):
def __init__(self, d_model, dropout=0.1):
super(ScaledDotProductAttention, self).__init__()
self.d_model = d_model
self.dropout = nn.Dropout(dropout)
def forward(self, query, key, value, mask=None):
batch_size, seq_len, d_model = query.size()
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_model)
if mask is not None:
mask_value = torch.finfo(scores.dtype).min
scores = scores.masked_fill(mask == 0, mask_value)
attention_weights = F.softmax(scores, dim=-1)
attention_weights = self.dropout(attention_weights)
output = torch.matmul(attention_weights, value)
return output, attention_weights
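# Focal Loss:FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t)。
# gamma 越大对易分样本的损失抑制越强,alpha 用于缓解类别不平衡;推理服务中仅在传入 labels 时参与计算。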
class FocalLoss(nn.Module):
def __init__(self, alpha=None, gamma=3.0, reduction='mean'):
super(FocalLoss, self).__init__()
self.alpha = alpha
self.gamma = gamma
self.reduction = reduction
def forward(self, inputs, targets):
ce_loss = F.cross_entropy(inputs, targets, reduction='none')
pt = torch.exp(-ce_loss)
if self.alpha is not None:
if self.alpha.type() != inputs.data.type():
self.alpha = self.alpha.type_as(inputs.data)
at = self.alpha.gather(0, targets.data.view(-1))
ce_loss = ce_loss * at
focal_weight = (1 - pt) ** self.gamma
focal_loss = focal_weight * ce_loss
if self.reduction == 'mean':
return focal_loss.mean()
elif self.reduction == 'sum':
return focal_loss.sum()
else:
return focal_loss
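# 双路径边界分类器:常规分类头与边界分类头共享经缩放点积注意力增强后的 [CLS] 表征;
# 边界检测头输出 boundary_score,经可学习的 boundary_force_weight 放大后,
# 作为偏置加到常规路径 logits 的"不同段落"维度上得到 final_logits;
# 边界路径与检测头只在传入 labels 时参与损失计算。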
class DualPathBoundaryClassifier(nn.Module):
def __init__(self, model_path, num_labels=2, dropout=0.1,
focal_alpha=None, focal_gamma=3.0, boundary_force_weight=2.0):
super(DualPathBoundaryClassifier, self).__init__()
self.roberta = BertModel.from_pretrained(model_path)
self.config = self.roberta.config
self.config.num_labels = num_labels
self.scaled_attention = ScaledDotProductAttention(
d_model=self.config.hidden_size,
dropout=dropout
)
self.dropout = nn.Dropout(dropout)
# 双路径分类器
self.regular_classifier = nn.Linear(self.config.hidden_size, num_labels)
self.boundary_classifier = nn.Linear(self.config.hidden_size, num_labels)
self.boundary_detector = nn.Linear(self.config.hidden_size, 1)
# 边界强制权重
self.boundary_force_weight = nn.Parameter(torch.tensor(boundary_force_weight))
self.focal_loss = FocalLoss(alpha=focal_alpha, gamma=focal_gamma)
self._init_weights()
self.focal_alpha = focal_alpha
self.focal_gamma = focal_gamma
def _init_weights(self):
nn.init.normal_(self.regular_classifier.weight, std=0.02)
nn.init.zeros_(self.regular_classifier.bias)
nn.init.normal_(self.boundary_classifier.weight, std=0.02)
nn.init.zeros_(self.boundary_classifier.bias)
nn.init.normal_(self.boundary_detector.weight, std=0.02)
nn.init.zeros_(self.boundary_detector.bias)
def forward(self, input_ids, attention_mask=None, token_type_ids=None, labels=None):
roberta_outputs = self.roberta(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
return_dict=True
)
sequence_output = roberta_outputs.last_hidden_state
# 缩放点积注意力增强
enhanced_output, attention_weights = self.scaled_attention(
query=sequence_output,
key=sequence_output,
value=sequence_output,
mask=attention_mask.unsqueeze(1) if attention_mask is not None else None
)
cls_output = enhanced_output[:, 0, :]
cls_output = self.dropout(cls_output)
# 双路径分类
regular_logits = self.regular_classifier(cls_output)
boundary_logits = self.boundary_classifier(cls_output)
# 边界检测
boundary_logits_raw = self.boundary_detector(cls_output).squeeze(-1)
boundary_score = torch.sigmoid(boundary_logits_raw)
# 动态融合
boundary_bias = torch.zeros_like(regular_logits)
boundary_bias[:, 1] = boundary_score * self.boundary_force_weight
final_logits = regular_logits + boundary_bias
loss = None
if labels is not None:
regular_loss = self.focal_loss(regular_logits, labels)
boundary_loss = self.focal_loss(boundary_logits, labels)
final_loss = self.focal_loss(final_logits, labels)
boundary_labels = self._generate_boundary_labels(labels)
detection_loss = F.binary_cross_entropy_with_logits(boundary_logits_raw, boundary_labels)
total_loss = (0.4 * final_loss +
0.3 * regular_loss +
0.2 * boundary_loss +
0.1 * detection_loss)
loss = total_loss
return {
'loss': loss,
'logits': final_logits,
'regular_logits': regular_logits,
'boundary_logits': boundary_logits,
'boundary_score': boundary_score,
'hidden_states': enhanced_output,
'attention_weights': attention_weights
}
def _generate_boundary_labels(self, labels):
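        # 将 0/1 硬标签加入少量随机噪声并截断到 [0, 1],作为边界检测头 BCE 监督用的软标签(仅在传入 labels 时调用)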
boundary_labels = labels.float()
noise = torch.rand_like(boundary_labels) * 0.1
boundary_labels = torch.clamp(boundary_labels + noise, 0.0, 1.0)
return boundary_labels
# ========== 模型加载部分 ==========
def check_gpu_availability():
if torch.cuda.is_available():
gpu_name = torch.cuda.get_device_name(0)
gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024 ** 3
print(f"🚀 GPU: {gpu_name} ({gpu_memory:.1f} GB)")
return torch.device('cuda')
else:
print("🔄 使用CPU")
return torch.device('cpu')
def safe_convert_focal_alpha(focal_alpha, device):
if focal_alpha is None:
return None
try:
if isinstance(focal_alpha, torch.Tensor):
return focal_alpha.to(device)
elif isinstance(focal_alpha, (list, tuple)):
return torch.tensor(focal_alpha, dtype=torch.float32).to(device)
elif isinstance(focal_alpha, np.ndarray):
return torch.from_numpy(focal_alpha).float().to(device)
elif isinstance(focal_alpha, (int, float)):
return torch.tensor([focal_alpha], dtype=torch.float32).to(device)
else:
return torch.tensor(focal_alpha, dtype=torch.float32).to(device)
except Exception as e:
print(f" 转换focal_alpha失败: {e}")
return None
def load_trained_dual_path_model(model_path, device, original_roberta_path=None):
try:
if original_roberta_path and os.path.exists(original_roberta_path):
tokenizer_path = original_roberta_path
else:
tokenizer_path = model_path
tokenizer = BertTokenizer.from_pretrained(tokenizer_path)
config_path = os.path.join(model_path, 'config.json')
if os.path.exists(config_path):
with open(config_path, 'r', encoding='utf-8') as f:
config = json.load(f)
focal_gamma = config.get('focal_gamma', 3.0)
focal_alpha_raw = config.get('focal_alpha', None)
boundary_force_weight = config.get('boundary_force_weight', 2.0)
else:
focal_gamma = 3.0
focal_alpha_raw = None
boundary_force_weight = 2.0
focal_alpha = safe_convert_focal_alpha(focal_alpha_raw, device)
if original_roberta_path and os.path.exists(original_roberta_path):
model_init_path = original_roberta_path
else:
model_init_path = tokenizer_path
model = DualPathBoundaryClassifier(
model_path=model_init_path,
num_labels=2,
dropout=0.1,
focal_alpha=focal_alpha,
focal_gamma=focal_gamma,
boundary_force_weight=boundary_force_weight
)
model_vocab_size = model.roberta.embeddings.word_embeddings.weight.shape[0]
if model_vocab_size != tokenizer.vocab_size:
raise ValueError(f"词汇表大小不匹配: 模型={model_vocab_size}, Tokenizer={tokenizer.vocab_size}")
model_weights_path = os.path.join(model_path, 'pytorch_model.bin')
if os.path.exists(model_weights_path):
try:
state_dict = torch.load(model_weights_path, map_location=device)
except Exception:
state_dict = torch.load(model_weights_path, map_location='cpu')
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
if missing_keys or unexpected_keys:
print(f" 加载权重时有差异: 缺少{len(missing_keys)}个键, 多余{len(unexpected_keys)}个键")
else:
raise FileNotFoundError(f"未找到模型权重文件: {model_weights_path}")
model.to(device)
model.eval()
print("✅ 双路径边界分类器模型加载完成")
return model, tokenizer
except Exception as e:
print(f"❌ 双路径模型加载失败: {str(e)}")
raise
# ========== 文本处理部分 ==========
def split_text_into_sentences(text: str) -> List[str]:
text = text.strip()
sentence_endings = r'([。!?;])'
parts = re.split(sentence_endings, text)
sentences = []
for i in range(0, len(parts), 2):
if i < len(parts):
sentence = parts[i].strip()
if sentence:
if i + 1 < len(parts):
sentence += parts[i + 1]
sentences.append(sentence)
return sentences
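# 贪心式分段:逐对比较相邻句子(去掉句尾标点后作为句对输入模型),预测为 0 则并入当前段落,为 1 则切段并新开一段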
def predict_sentence_pairs_dual_path(sentences: List[str], model, tokenizer, device, max_length=384) -> Dict[str, str]:
if len(sentences) < 2:
return {"paragraph_1": sentences[0] if sentences else ""}
results = {}
current_paragraph_sentences = [sentences[0]]
with torch.no_grad():
for i in range(len(sentences) - 1):
sentence1_clean = re.sub(r'[。!?;]$', '', sentences[i])
sentence2_clean = re.sub(r'[。!?;]$', '', sentences[i + 1])
encoding = tokenizer(
sentence1_clean,
sentence2_clean,
truncation=True,
padding='max_length',
max_length=max_length,
return_tensors='pt'
)
input_ids = encoding['input_ids'].to(device)
attention_mask = encoding['attention_mask'].to(device)
outputs = model(
input_ids=input_ids,
attention_mask=attention_mask
)
logits = outputs['logits']
prediction = torch.argmax(logits, dim=-1).item()
if prediction == 0: # 同段落
current_paragraph_sentences.append(sentences[i + 1])
else: # 不同段落
paragraph_key = f"paragraph_{len(results) + 1}"
results[paragraph_key] = "".join(current_paragraph_sentences)
current_paragraph_sentences = [sentences[i + 1]]
if current_paragraph_sentences:
paragraph_key = f"paragraph_{len(results) + 1}"
results[paragraph_key] = "".join(current_paragraph_sentences)
return results
def process_single_broadcast_dual_path(text: str, broadcast_id: Optional[str] = None) -> dict:
try:
if not text or not text.strip():
return {
"broadcast_id": broadcast_id,
"segments": {},
"status": "failed",
"error": "文本为空"
}
sentences = split_text_into_sentences(text)
if len(sentences) == 0:
return {
"broadcast_id": broadcast_id,
"segments": {"paragraph_1": text},
"status": "success"
}
segments = predict_sentence_pairs_dual_path(sentences, model, tokenizer, device)
return {
"broadcast_id": broadcast_id,
"segments": segments,
"status": "success"
}
except Exception as e:
return {
"broadcast_id": broadcast_id,
"segments": {},
"status": "failed",
"error": str(e)
}
# ========== FastAPI应用部分 ==========
app = FastAPI(title="双路径边界分类器文本分段服务", version="3.0.0")
# 全局变量存储模型
model = None
tokenizer = None
device = None
# ========== 请求和响应模型 ==========
class TextInput(BaseModel):
广播内容: str
class BroadcastItem(BaseModel):
广播内容: str
广播ID: Optional[str] = None
BatchInput = List[BroadcastItem]
# ========== 生命周期事件 ==========
@app.on_event("startup")
async def load_model():
global model, tokenizer, device
model_path = "/work/model_robert/model_train-eval"
original_roberta_path = "/work/model_robert/model"
print("🚀 正在启动双路径边界分类器文本分段服务...")
try:
device = check_gpu_availability()
model, tokenizer = load_trained_dual_path_model(model_path, device, original_roberta_path)
print(f"✅ 双路径边界分类器模型加载成功! 设备: {device}")
# print(f"📝 词汇表大小: {tokenizer.vocab_size}")
# print(f"🎯 序列最大长度: 384 tokens")
except Exception as e:
print(f"❌ 双路径模型加载失败: {e}")
raise
# ========== API接口 ==========
@app.get("/")
async def root():
"""根路径 - 服务状态检查"""
return {
"service": "双路径边界分类器文本分段服务",
"status": "运行中" if model is not None else "模型未加载",
"version": "3.0.0",
"model_type": "DualPathBoundaryClassifier",
"features": [
"双路径架构: 常规分类器 + 边界分类器",
"边界检测器: 纯神经网络学习边界模式",
"动态权重融合: 自适应边界识别",
"序列长度: 384 tokens",
"单条处理",
"批量处理(简化)",
"详细分析"
]
}
@app.get("/health")
async def health_check():
"""健康检查接口"""
return {
"status": "healthy" if model is not None else "unhealthy",
"model_loaded": model is not None,
"model_type": "DualPathBoundaryClassifier",
"device": str(device) if device else None,
"max_sequence_length": 384,
"boundary_detection": "pure_neural_network"
}
@app.post("/segment_simple")
async def segment_text_simple(input_data: TextInput):
"""单条文本分段接口"""
if model is None or tokenizer is None:
return {"error": "双路径模型未加载"}
try:
result = process_single_broadcast_dual_path(input_data.广播内容)
if result["status"] == "success":
return result["segments"]
else:
return {"error": result["error"]}
except Exception as e:
return {"error": f"双路径处理失败: {str(e)}"}
@app.post("/segment_batch_simple")
async def segment_batch_simple(broadcasts: BatchInput):
"""批量文本分段接口 - 简化输出"""
if model is None or tokenizer is None:
return {"error": "双路径模型未加载"}
try:
if not broadcasts:
return {"error": "广播列表不能为空"}
start_time = time.time()
results = []
# 并行处理
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
future_to_broadcast = {
executor.submit(process_single_broadcast_dual_path,
broadcast.广播内容,
broadcast.广播ID): broadcast
for broadcast in broadcasts
}
for future in concurrent.futures.as_completed(future_to_broadcast):
try:
result = future.result()
results.append(result)
except Exception as e:
broadcast = future_to_broadcast[future]
results.append({
"broadcast_id": broadcast.广播ID,
"status": "failed",
"error": f"双路径处理异常: {str(e)}"
})
# 简化输出格式
simplified_results = {}
success_count = 0
for i, result in enumerate(results):
key = result.get("broadcast_id") or f"broadcast_{i + 1}"
if result["status"] == "success":
simplified_results[key] = result["segments"]
success_count += 1
else:
simplified_results[key] = {"error": result.get("error", "双路径处理失败")}
total_time = time.time() - start_time
return {
"model": "DualPathBoundaryClassifier",
"total": len(results),
"success": success_count,
"failed": len(results) - success_count,
"processing_time": round(total_time, 3),
"results": simplified_results
}
except Exception as e:
return {"error": f"双路径批量处理失败: {str(e)}"}
@app.post("/segment_with_details")
async def segment_with_details(input_data: TextInput):
"""带详细信息的文本分段接口"""
if model is None or tokenizer is None:
return {"error": "双路径模型未加载"}
try:
text = input_data.广播内容
sentences = split_text_into_sentences(text)
if len(sentences) < 2:
return {
"segments": {"paragraph_1": text},
"sentence_details": [],
"total_sentences": len(sentences)
}
results = {}
current_paragraph_sentences = [sentences[0]]
sentence_details = []
with torch.no_grad():
for i in range(len(sentences) - 1):
sentence1_clean = re.sub(r'[。!?;]$', '', sentences[i])
sentence2_clean = re.sub(r'[。!?;]$', '', sentences[i + 1])
encoding = tokenizer(
sentence1_clean,
sentence2_clean,
truncation=True,
padding='max_length',
max_length=384,
return_tensors='pt'
)
input_ids = encoding['input_ids'].to(device)
attention_mask = encoding['attention_mask'].to(device)
outputs = model(
input_ids=input_ids,
attention_mask=attention_mask
)
logits = outputs['logits']
regular_logits = outputs['regular_logits']
boundary_logits = outputs['boundary_logits']
boundary_score = outputs['boundary_score'].item()
prediction = torch.argmax(logits, dim=-1).item()
probability = torch.softmax(logits, dim=-1)
regular_prob = torch.softmax(regular_logits, dim=-1)
boundary_prob = torch.softmax(boundary_logits, dim=-1)
detail = {
"sentence_pair_index": i + 1,
"sentence1": sentences[i],
"sentence2": sentences[i + 1],
"prediction": prediction,
"prediction_label": "same_paragraph" if prediction == 0 else "different_paragraph",
"final_probabilities": {
"same_paragraph": float(probability[0][0]),
"different_paragraph": float(probability[0][1])
},
"regular_path_probabilities": {
"same_paragraph": float(regular_prob[0][0]),
"different_paragraph": float(regular_prob[0][1])
},
"boundary_path_probabilities": {
"same_paragraph": float(boundary_prob[0][0]),
"different_paragraph": float(boundary_prob[0][1])
},
"boundary_score": boundary_score,
"boundary_confidence": "high" if boundary_score > 0.7 else "medium" if boundary_score > 0.3 else "low"
}
sentence_details.append(detail)
if prediction == 0: # 同段落
current_paragraph_sentences.append(sentences[i + 1])
else: # 不同段落
paragraph_key = f"paragraph_{len(results) + 1}"
results[paragraph_key] = "".join(current_paragraph_sentences)
current_paragraph_sentences = [sentences[i + 1]]
if current_paragraph_sentences:
paragraph_key = f"paragraph_{len(results) + 1}"
results[paragraph_key] = "".join(current_paragraph_sentences)
boundary_scores = [d["boundary_score"] for d in sentence_details]
return {
"segments": results,
"sentence_details": sentence_details,
"total_sentences": len(sentences),
"total_pairs_analyzed": len(sentence_details),
"boundary_statistics": {
"average_boundary_score": round(sum(boundary_scores) / len(boundary_scores),
4) if boundary_scores else 0,
"max_boundary_score": round(max(boundary_scores), 4) if boundary_scores else 0,
"min_boundary_score": round(min(boundary_scores), 4) if boundary_scores else 0
}
}
except Exception as e:
return {"error": f"详细分析失败: {str(e)}"}
@app.get("/stats")
async def get_processing_stats():
"""获取处理统计信息"""
return {
"service_status": "running" if model is not None else "down",
"model_loaded": model is not None,
"model_type": "DualPathBoundaryClassifier",
"device": str(device) if device else None,
"vocab_size": tokenizer.vocab_size if tokenizer else None,
"max_sequence_length": 384,
"api_endpoints": [
"/segment_simple - 单条处理",
"/segment_batch_simple - 批量处理(简化)",
"/segment_with_details - 带详细信息分段"
]
}
# ========== 启动配置 ==========
if __name__ == "__main__":
uvicorn.run(app, host='0.0.0.0', port=8888)
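# 调用示例(示意;假设服务已在本机 8888 端口启动,字段名以上文 TextInput 定义为准):
#   curl -X POST http://127.0.0.1:8888/segment_simple \
#        -H "Content-Type: application/json" \
#        -d '{"广播内容": "欢迎收听健康快车。我是小月。下面请收听普法档案。"}'
# 成功时返回形如 {"paragraph_1": "...", "paragraph_2": "..."} 的分段结果。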

996
5.AI标注-model_trian/封装模型/Project1/app.py

@@ -0,0 +1,996 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import json
import os
import math
import logging
import re
from datetime import datetime
from typing import List, Dict, Any, Optional
from collections import Counter
from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from transformers import BertTokenizer, BertModel
import uvicorn
# 设置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# 全局变量
model = None
tokenizer = None
device = None
class ScaledDotProductAttention(nn.Module):
"""缩放点积注意力机制"""
def __init__(self, d_model, dropout=0.1):
super(ScaledDotProductAttention, self).__init__()
self.d_model = d_model
self.dropout = nn.Dropout(dropout)
def forward(self, query, key, value, mask=None):
batch_size, seq_len, d_model = query.size()
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_model)
if mask is not None:
mask_value = torch.finfo(scores.dtype).min
scores = scores.masked_fill(mask == 0, mask_value)
attention_weights = F.softmax(scores, dim=-1)
attention_weights = self.dropout(attention_weights)
output = torch.matmul(attention_weights, value)
return output, attention_weights
class DualPathBoundaryClassifier(nn.Module):
"""双路径边界分类器,完全依靠神经网络学习边界模式"""
def __init__(self, model_path, num_labels=2, dropout=0.1, boundary_force_weight=2.0):
super(DualPathBoundaryClassifier, self).__init__()
self.roberta = BertModel.from_pretrained(model_path)
self.config = self.roberta.config
self.config.num_labels = num_labels
self.scaled_attention = ScaledDotProductAttention(
d_model=self.config.hidden_size,
dropout=dropout
)
self.dropout = nn.Dropout(dropout)
# 双路径分类器
self.regular_classifier = nn.Linear(self.config.hidden_size, num_labels)
self.boundary_classifier = nn.Linear(self.config.hidden_size, num_labels)
self.boundary_detector = nn.Linear(self.config.hidden_size, 1)
# 边界强制权重
self.boundary_force_weight = nn.Parameter(torch.tensor(boundary_force_weight))
def forward(self, input_ids, attention_mask=None, token_type_ids=None):
roberta_outputs = self.roberta(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
return_dict=True
)
sequence_output = roberta_outputs.last_hidden_state
# 缩放点积注意力增强
enhanced_output, attention_weights = self.scaled_attention(
query=sequence_output,
key=sequence_output,
value=sequence_output,
mask=attention_mask.unsqueeze(1) if attention_mask is not None else None
)
cls_output = enhanced_output[:, 0, :]
cls_output = self.dropout(cls_output)
# 双路径分类
regular_logits = self.regular_classifier(cls_output)
boundary_logits = self.boundary_classifier(cls_output)
# 边界检测
boundary_logits_raw = self.boundary_detector(cls_output).squeeze(-1)
boundary_score = torch.sigmoid(boundary_logits_raw)
# 动态融合
boundary_bias = torch.zeros_like(regular_logits)
boundary_bias[:, 1] = boundary_score * self.boundary_force_weight
final_logits = regular_logits + boundary_bias
return {
'logits': final_logits,
'regular_logits': regular_logits,
'boundary_logits': boundary_logits,
'boundary_score': boundary_score,
'hidden_states': enhanced_output,
'attention_weights': attention_weights
}
def load_model():
"""加载训练好的模型"""
global model, tokenizer, device
# 检查GPU
if torch.cuda.is_available():
device = torch.device('cuda')
gpu_name = torch.cuda.get_device_name(0)
logger.info(f"✅ 使用GPU: {gpu_name}")
else:
device = torch.device('cpu')
logger.info(" 使用CPU运行")
# 模型路径配置
model_path = r"D:\workstation\chinese-roberta-wwm-ext\model-train-eval-NN\model_train"
original_model_path = r"D:\workstation\chinese-roberta-wwm-ext\model"
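    # model_path:微调输出目录(含 pytorch_model.bin、config.json);
    # original_model_path:原始 RoBERTa 目录,用于初始化骨干并保证 tokenizer 词表一致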
logger.info(f"📥 加载双路径边界分类器模型...")
try:
# 先检查训练模型目录是否存在配置文件
config_path = os.path.join(model_path, 'config.json')
if os.path.exists(config_path):
with open(config_path, 'r', encoding='utf-8') as f:
model_config = json.load(f)
boundary_force_weight = model_config.get('boundary_force_weight', 2.0)
logger.info(f" 🔹 边界强制权重: {boundary_force_weight}")
else:
boundary_force_weight = 2.0
# 先尝试从训练目录加载tokenizer,如果失败则使用原始目录
try:
tokenizer = BertTokenizer.from_pretrained(model_path)
logger.info(f" ✅ 从训练目录加载tokenizer成功")
except Exception as e:
logger.warning(f" 从训练目录加载tokenizer失败: {str(e)}")
tokenizer = BertTokenizer.from_pretrained(original_model_path)
logger.info(f" ✅ 从原始目录加载tokenizer成功")
logger.info(f" 🔹 词汇表大小: {len(tokenizer.vocab)}")
# 创建模型实例,使用原始模型路径以避免词汇表不匹配
logger.info(f" 🔧 使用原始模型路径创建模型实例")
model = DualPathBoundaryClassifier(
model_path=original_model_path, # 强制使用原始模型路径
num_labels=2,
dropout=0.1,
boundary_force_weight=boundary_force_weight
)
# 加载训练好的权重
model_weights_path = os.path.join(model_path, 'pytorch_model.bin')
if os.path.exists(model_weights_path):
logger.info(f" 📥 加载训练权重...")
state_dict = torch.load(model_weights_path, map_location=device)
# 尝试加载权重,如果失败则使用更安全的方法
try:
model.load_state_dict(state_dict)
logger.info(f" ✅ 成功加载完整权重")
except RuntimeError as e:
if "size mismatch" in str(e):
logger.warning(f" 检测到权重尺寸不匹配,使用兼容性加载")
# 过滤掉不匹配的权重
model_dict = model.state_dict()
filtered_dict = {}
for k, v in state_dict.items():
if k in model_dict:
if model_dict[k].shape == v.shape:
filtered_dict[k] = v
else:
logger.warning(
f" 跳过不匹配的权重: {k} (模型: {model_dict[k].shape}, 检查点: {v.shape})")
else:
logger.warning(f" 跳过未知权重: {k}")
# 加载过滤后的权重
model_dict.update(filtered_dict)
model.load_state_dict(model_dict)
logger.info(f" ✅ 成功加载兼容权重 ({len(filtered_dict)}/{len(state_dict)} 权重已加载)")
else:
raise e
else:
logger.error(f" ❌ 找不到模型权重文件: {model_weights_path}")
return False
model.to(device)
model.eval()
total_params = sum(p.numel() for p in model.parameters())
logger.info(f"📊 模型参数: {total_params:,}")
# 测试模型是否正常工作
logger.info("🧪 测试模型推理...")
test_result = test_model_inference()
if test_result:
logger.info("✅ 模型推理测试通过")
logger.info("🚀 模型加载完成,Ready for service!")
return True
else:
logger.error("❌ 模型推理测试失败")
return False
except Exception as e:
logger.error(f"❌ 模型加载失败: {str(e)}")
import traceback
traceback.print_exc()
return False
def test_model_inference():
"""测试模型推理是否正常"""
try:
test_sentences = [
"这是第一个测试句子。",
"这是第二个测试句子。"
]
with torch.no_grad():
inputs = tokenizer(
test_sentences[0],
test_sentences[1],
truncation=True,
padding=True,
max_length=512,
return_tensors='pt'
)
inputs = {k: v.to(device) for k, v in inputs.items()}
outputs = model(**inputs)
# 检查输出格式
required_keys = ['logits', 'boundary_score']
for key in required_keys:
if key not in outputs:
logger.error(f"模型输出缺少必要的键: {key}")
return False
logits = outputs['logits']
boundary_score = outputs['boundary_score']
# 检查输出形状
if logits.shape != torch.Size([1, 2]):
logger.error(f"logits形状不正确: {logits.shape}, 期望: [1, 2]")
return False
if boundary_score.shape != torch.Size([1]):
logger.error(f"boundary_score形状不正确: {boundary_score.shape}, 期望: [1]")
return False
# 检查数值范围
if not (0 <= boundary_score.item() <= 1):
logger.error(f"boundary_score超出范围: {boundary_score.item()}")
return False
logger.info(f" 测试预测: logits={logits.tolist()}, boundary_score={boundary_score.item():.3f}")
return True
except Exception as e:
logger.error(f"模型推理测试异常: {str(e)}")
return False
def split_text_into_sentences(text: str) -> List[str]:
"""将文本按句号、感叹号、问号分割成句子"""
# 中文句子分割规则
sentence_endings = r'[。!?!?]'
sentences = re.split(sentence_endings, text)
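    # 说明:re.split 不保留分隔符,下面按顺序在原文中查找各句位置以补回句尾标点;
    # 若同一句子在文本中重复出现,find 定位可能产生偏差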
# 过滤空句子,保留标点符号
result = []
for i, sentence in enumerate(sentences):
sentence = sentence.strip()
if sentence:
# 如果不是最后一个句子,添加标点符号
if i < len(sentences) - 1:
# 找到原始标点符号
original_text = text
start_pos = 0
for j in range(i):
if sentences[j].strip():
start_pos = original_text.find(sentences[j].strip(), start_pos) + len(sentences[j].strip())
# 查找句子后的标点符号
remaining_text = original_text[start_pos:]
punctuation_match = re.search(sentence_endings, remaining_text)
if punctuation_match:
sentence += punctuation_match.group()
result.append(sentence)
return result
def predict_sentence_pairs(sentences: List[str]) -> List[Dict[str, Any]]:
"""预测相邻句子对是否需要分段"""
if len(sentences) < 2:
return []
results = []
with torch.no_grad():
for i in range(len(sentences) - 1):
sentence1 = sentences[i]
sentence2 = sentences[i + 1]
# Tokenization
inputs = tokenizer(
sentence1,
sentence2,
truncation=True,
padding=True,
max_length=512,
return_tensors='pt'
)
# 移动到设备
inputs = {k: v.to(device) for k, v in inputs.items()}
# 模型预测
outputs = model(**inputs)
# 获取预测结果
logits = outputs['logits']
boundary_score = outputs['boundary_score']
probs = F.softmax(logits, dim=-1)
prediction = torch.argmax(logits, dim=-1).item()
confidence = torch.max(probs, dim=-1)[0].item()
boundary_score_value = boundary_score.item()
# 结果
result = {
'sentence1': sentence1,
'sentence2': sentence2,
'prediction': prediction, # 0: 同段落, 1: 不同段落
'confidence': confidence,
'boundary_score': boundary_score_value,
'should_split': prediction == 1,
'split_reason': get_split_reason(prediction, boundary_score_value, confidence)
}
results.append(result)
return results
def get_split_reason(prediction: int, boundary_score: float, confidence: float) -> str:
"""生成分段原因说明"""
if prediction == 1:
if boundary_score > 0.7:
return f"检测到强边界信号 (边界分数: {boundary_score:.3f})"
elif confidence > 0.8:
return f"语义转换明显 (置信度: {confidence:.3f})"
else:
return f"建议分段 (置信度: {confidence:.3f})"
else:
return f"内容连贯,无需分段 (置信度: {confidence:.3f})"
def segment_text(text: str) -> Dict[str, Any]:
"""对完整文本进行分段处理"""
# 分割成句子
sentences = split_text_into_sentences(text)
if len(sentences) < 2:
return {
'original_text': text,
'sentences': sentences,
'segments': [text] if text.strip() else [],
'split_decisions': [],
'total_sentences': len(sentences),
'total_segments': 1 if text.strip() else 0
}
# 预测相邻句子对
split_decisions = predict_sentence_pairs(sentences)
# 根据预测结果进行分段
segments = []
current_segment = [sentences[0]]
for i, decision in enumerate(split_decisions):
if decision['should_split']:
# 需要分段,结束当前段落
segments.append(''.join(current_segment))
current_segment = [sentences[i + 1]]
else:
# 不需要分段,继续当前段落
current_segment.append(sentences[i + 1])
# 添加最后一个段落
if current_segment:
segments.append(''.join(current_segment))
return {
'original_text': text,
'sentences': sentences,
'segments': segments,
'split_decisions': split_decisions,
'total_sentences': len(sentences),
'total_segments': len(segments)
}
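# segment_text 返回结构示意(字段以上面的实现为准,数值仅为示例):
# {
#   "original_text": "...", "sentences": ["句子1。", "句子2。"],
#   "segments": ["第一段...", "第二段..."],
#   "split_decisions": [{"should_split": true, "confidence": 0.93, "boundary_score": 0.81, ...}],
#   "total_sentences": 3, "total_segments": 2
# }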
# FastAPI 应用
app = FastAPI(title="双路径边界分类器文本分段服务", version="1.0.0")
# 请求模型
class TextInput(BaseModel):
text: str
class BatchTextInput(BaseModel):
texts: List[str]
# 响应模型
class SegmentResult(BaseModel):
original_text: str
sentences: List[str]
segments: List[str]
total_sentences: int
total_segments: int
class DetailedSegmentResult(BaseModel):
original_text: str
sentences: List[str]
segments: List[str]
split_decisions: List[Dict[str, Any]]
total_sentences: int
total_segments: int
@app.on_event("startup")
async def startup_event():
"""启动时加载模型"""
logger.info("🚀 启动双路径边界分类器服务...")
success = load_model()
if not success:
logger.error("❌ 模型加载失败,服务无法启动")
raise RuntimeError("模型加载失败")
logger.info("✅ 服务启动成功!")
@app.get("/", response_class=HTMLResponse)
async def get_frontend():
"""返回前端页面"""
html_content = """
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>双路径边界分类器 - 文本分段服务</title>
<style>
body {
font-family: 'Arial', sans-serif;
margin: 0;
padding: 20px;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
min-height: 100vh;
}
.container {
max-width: 1200px;
margin: 0 auto;
background: white;
border-radius: 15px;
box-shadow: 0 10px 30px rgba(0,0,0,0.2);
overflow: hidden;
}
.header {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
padding: 30px;
text-align: center;
}
.header h1 {
margin: 0;
font-size: 2.5em;
font-weight: bold;
}
.header p {
margin: 10px 0 0 0;
font-size: 1.2em;
opacity: 0.9;
}
.content {
padding: 30px;
}
.input-section {
margin-bottom: 30px;
}
.input-section label {
display: block;
margin-bottom: 10px;
font-weight: bold;
color: #333;
font-size: 1.1em;
}
.input-section textarea {
width: 100%;
height: 200px;
padding: 15px;
border: 2px solid #ddd;
border-radius: 10px;
font-size: 16px;
font-family: 'Arial', sans-serif;
resize: vertical;
box-sizing: border-box;
}
.input-section textarea:focus {
outline: none;
border-color: #667eea;
box-shadow: 0 0 10px rgba(102, 126, 234, 0.3);
}
.button-group {
display: flex;
gap: 15px;
margin-top: 20px;
}
.btn {
padding: 12px 25px;
border: none;
border-radius: 8px;
font-size: 16px;
font-weight: bold;
cursor: pointer;
transition: all 0.3s ease;
}
.btn-primary {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
}
.btn-primary:hover {
transform: translateY(-2px);
box-shadow: 0 5px 15px rgba(102, 126, 234, 0.4);
}
.btn-secondary {
background: #f8f9fa;
color: #333;
border: 2px solid #ddd;
}
.btn-secondary:hover {
background: #e9ecef;
}
.btn:disabled {
opacity: 0.6;
cursor: not-allowed;
transform: none !important;
}
.loading {
display: none;
text-align: center;
padding: 20px;
color: #667eea;
font-weight: bold;
}
.loading::after {
content: "";
display: inline-block;
margin-left: 10px;
width: 20px;
height: 20px;
border: 3px solid #f3f3f3;
border-top: 3px solid #667eea;
border-radius: 50%;
animation: spin 1s linear infinite;
}
@keyframes spin {
0% { transform: rotate(0deg); }
100% { transform: rotate(360deg); }
}
.result-section {
margin-top: 30px;
display: none;
}
.result-header {
background: #f8f9fa;
padding: 15px;
border-radius: 10px;
margin-bottom: 20px;
border-left: 5px solid #667eea;
}
.segment {
background: #f8f9fa;
padding: 15px;
margin: 10px 0;
border-radius: 8px;
border-left: 4px solid #28a745;
position: relative;
}
.segment-header {
font-weight: bold;
color: #28a745;
margin-bottom: 8px;
}
.segment-content {
line-height: 1.6;
color: #333;
}
.split-decision {
background: white;
padding: 10px 15px;
margin: 8px 0;
border-radius: 6px;
border-left: 3px solid #ffc107;
font-size: 14px;
}
.split-decision.split {
border-left-color: #dc3545;
background: #fff5f5;
}
.split-decision.no-split {
border-left-color: #28a745;
background: #f5fff5;
}
.stats {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
gap: 15px;
margin: 20px 0;
}
.stat-card {
background: white;
padding: 20px;
border-radius: 10px;
text-align: center;
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
}
.stat-number {
font-size: 2em;
font-weight: bold;
color: #667eea;
}
.stat-label {
color: #666;
margin-top: 5px;
}
.demo-buttons {
display: flex;
gap: 10px;
margin-top: 15px;
flex-wrap: wrap;
}
.demo-btn {
padding: 8px 15px;
background: #e9ecef;
border: 1px solid #ddd;
border-radius: 5px;
cursor: pointer;
font-size: 14px;
transition: all 0.2s ease;
}
.demo-btn:hover {
background: #dee2e6;
border-color: #667eea;
}
</style>
</head>
<body>
<div class="container">
<div class="header">
<h1>🤖 双路径边界分类器</h1>
<p>智能文本分段服务 - 基于神经网络的广播内容自动分段</p>
</div>
<div class="content">
<div class="input-section">
<label for="textInput">📝 请输入需要分段的文本内容</label>
<textarea id="textInput" placeholder="请输入广播内容或其他需要分段的文本...&#10;&#10;示例:&#10;小月提醒大家,不认识的野菜一定不要去采,避免因误食发生过敏或者中毒。好了,健康快车赶快上车,我是小月,我们下期再会。下面即将收听到的是普法档案。欢迎收听普法档案,我是沐白。"></textarea>
<div class="demo-buttons">
<button class="demo-btn" onclick="loadDemo(1)">📻 广播节目示例</button>
<button class="demo-btn" onclick="loadDemo(2)">📰 新闻内容示例</button>
<button class="demo-btn" onclick="loadDemo(3)">📚 教育内容示例</button>
<button class="demo-btn" onclick="clearText()">🗑 清空</button>
</div>
<div class="button-group">
<button class="btn btn-primary" onclick="processText()">🚀 开始分段</button>
<button class="btn btn-secondary" onclick="clearResults()">🔄 清除结果</button>
</div>
</div>
<div class="loading" id="loading">
                正在分析文本,请稍候...
</div>
<div class="result-section" id="resultSection">
<div class="result-header">
<h3>📊 分段结果</h3>
<div class="stats" id="stats"></div>
</div>
<div id="segments"></div>
<details style="margin-top: 20px;">
<summary style="cursor: pointer; font-weight: bold; color: #667eea;">🔍 查看详细分析过程</summary>
<div id="detailedAnalysis" style="margin-top: 15px;"></div>
</details>
</div>
</div>
</div>
<script>
const demoTexts = {
1: `小月提醒大家不认识的野菜一定不要去采避免因误食发生过敏或者中毒食用野菜前最好留存少许的野菜或者先拍照一旦发生不适要停止食用立即催吐然后携带剩余野菜呕吐物或者之前拍的照片及时就医好了健康快车赶快上车我是小月我们下期再会下面即将收听到的是普法档案欢迎收听普法档案我是沐白今天我们来讨论一个重要的法律问题这个问题涉及到合同纠纷的处理方式感谢大家收听今天的节目内容接下来为您播放轻音乐时光`,
2: `今日上午市政府召开新闻发布会宣布了新的城市规划方案该方案将重点发展科技创新产业预计投资总额达到500亿元据了解新规划将涵盖教育医疗交通等多个领域以上就是今天的新闻内容现在为大家播放天气预报明天将是晴朗的一天气温在15到25度之间请大家注意适当增减衣物`,
3: `在学习语言的过程中我们需要掌握几个重要的原则首先是要多听多说培养语感其次是要注意语法的正确性避免常见错误最后是要多阅读扩大词汇量今天的语言学习课程就到这里下面请收听音乐欣赏节目今天为大家介绍的是古典音乐的魅力音乐能够陶冶情操提升审美能力`
};
function loadDemo(num) {
document.getElementById('textInput').value = demoTexts[num];
}
function clearText() {
document.getElementById('textInput').value = '';
}
function clearResults() {
document.getElementById('resultSection').style.display = 'none';
}
async function processText() {
const text = document.getElementById('textInput').value.trim();
if (!text) {
alert('请输入要分段的文本内容!');
return;
}
// 显示加载状态
document.getElementById('loading').style.display = 'block';
document.getElementById('resultSection').style.display = 'none';
try {
const response = await fetch('/segment', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({ text: text })
});
if (!response.ok) {
throw new Error('网络请求失败');
}
const result = await response.json();
displayResults(result);
} catch (error) {
alert('处理失败:' + error.message);
} finally {
document.getElementById('loading').style.display = 'none';
}
}
function displayResults(result) {
// 显示统计信息
const statsHtml = `
<div class="stat-card">
<div class="stat-number">${result.total_sentences}</div>
<div class="stat-label">总句子数</div>
</div>
<div class="stat-card">
<div class="stat-number">${result.total_segments}</div>
<div class="stat-label">分段数量</div>
</div>
<div class="stat-card">
<div class="stat-number">${(result.split_decisions || []).filter(d => d.should_split).length}</div>
<div class="stat-label">分段点数</div>
</div>
`;
document.getElementById('stats').innerHTML = statsHtml;
// 显示分段结果
const segmentsHtml = result.segments.map((segment, index) => `
<div class="segment">
<div class="segment-header">段落 ${index + 1}</div>
<div class="segment-content">${segment}</div>
</div>
`).join('');
document.getElementById('segments').innerHTML = segmentsHtml;
// 显示详细分析
if (result.split_decisions && result.split_decisions.length > 0) {
const analysisHtml = result.split_decisions.map((decision, index) => `
<div class="split-decision ${decision.should_split ? 'split' : 'no-split'}">
<strong>句子对 ${index + 1}:</strong><br>
<div style="margin: 5px 0;">
<strong>句子1:</strong> ${decision.sentence1}<br>
<strong>句子2:</strong> ${decision.sentence2}
</div>
<div style="margin: 5px 0;">
<strong>决策:</strong> ${decision.should_split ? '🔴 需要分段' : '🟢 无需分段'}<br>
<strong>置信度:</strong> ${(decision.confidence * 100).toFixed(1)}%<br>
<strong>边界分数:</strong> ${(decision.boundary_score * 100).toFixed(1)}%<br>
<strong>原因:</strong> ${decision.split_reason}
</div>
</div>
`).join('');
document.getElementById('detailedAnalysis').innerHTML = analysisHtml;
}
document.getElementById('resultSection').style.display = 'block';
document.getElementById('resultSection').scrollIntoView({ behavior: 'smooth' });
}
</script>
</body>
</html>
"""
return HTMLResponse(content=html_content)
@app.post("/segment", response_model=DetailedSegmentResult)
async def segment_text_api(input_data: TextInput):
"""文本分段API"""
try:
if not input_data.text.strip():
raise HTTPException(status_code=400, detail="输入文本不能为空")
result = segment_text(input_data.text)
return DetailedSegmentResult(**result)
except Exception as e:
logger.error(f"文本分段处理失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"处理失败: {str(e)}")
@app.post("/segment/batch")
async def segment_batch_api(input_data: BatchTextInput):
"""批量文本分段API"""
try:
if not input_data.texts:
raise HTTPException(status_code=400, detail="输入文本列表不能为空")
results = []
for i, text in enumerate(input_data.texts):
if text.strip():
result = segment_text(text)
result['text_index'] = i
results.append(result)
return {"results": results, "total_processed": len(results)}
except Exception as e:
logger.error(f"批量文本分段处理失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"处理失败: {str(e)}")
@app.get("/health")
async def health_check():
"""健康检查API"""
return {
"status": "healthy",
"model_loaded": model is not None,
"device": str(device) if device else "unknown",
"timestamp": datetime.now().isoformat()
}
@app.get("/model/info")
async def model_info():
"""模型信息API"""
if model is None:
raise HTTPException(status_code=503, detail="模型未加载")
try:
total_params = sum(p.numel() for p in model.parameters())
return {
"model_type": "DualPathBoundaryClassifier",
"total_parameters": total_params,
"device": str(device),
"boundary_force_weight": float(model.boundary_force_weight.data),
"vocab_size": len(tokenizer.vocab),
"max_length": 512
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"获取模型信息失败: {str(e)}")
@app.post("/predict/pair")
async def predict_sentence_pair(sentence1: str, sentence2: str):
"""预测单个句子对是否需要分段"""
try:
if not sentence1.strip() or not sentence2.strip():
raise HTTPException(status_code=400, detail="句子不能为空")
# 使用现有的预测函数
decisions = predict_sentence_pairs([sentence1, sentence2])
if decisions:
decision = decisions[0]
return {
"sentence1": sentence1,
"sentence2": sentence2,
"should_split": decision['should_split'],
"confidence": decision['confidence'],
"boundary_score": decision['boundary_score'],
"split_reason": decision['split_reason']
}
else:
raise HTTPException(status_code=500, detail="预测失败")
except HTTPException:
raise
except Exception as e:
logger.error(f"句子对预测失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"预测失败: {str(e)}")
def main():
"""启动FastAPI服务"""
logger.info("🚀 启动双路径边界分类器FastAPI服务")
logger.info(f"📝 服务地址: http://0.0.0.0:8888")
logger.info(f"🌐 前端界面: http://0.0.0.0:8888")
logger.info(f"📚 API文档: http://0.0.0.0:8888/docs")
# 启动服务
uvicorn.run(
app,
host="0.0.0.0",
port=8888,
log_level="info",
access_log=True
)
if __name__ == "__main__":
main()
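# 运行示例(示意;假设依赖已安装且上述模型路径存在):
#   python app.py
# 启动后可访问 http://127.0.0.1:8888 使用前端页面,
# 或 POST http://127.0.0.1:8888/segment,请求体为 {"text": "..."};API 文档见 /docs。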

656
5.AI标注-model_trian/无标点符号训练/Fastapi.py

@@ -0,0 +1,656 @@
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import os
import re
import time
import concurrent.futures
from typing import Dict, List, Optional, Union
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from transformers import BertTokenizer, BertModel
import uvicorn
# ========== 双路径边界分类器模型定义部分 ==========
class ScaledDotProductAttention(nn.Module):
def __init__(self, d_model, dropout=0.1):
super(ScaledDotProductAttention, self).__init__()
self.d_model = d_model
self.dropout = nn.Dropout(dropout)
def forward(self, query, key, value, mask=None):
batch_size, seq_len, d_model = query.size()
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_model)
if mask is not None:
mask_value = torch.finfo(scores.dtype).min
scores = scores.masked_fill(mask == 0, mask_value)
attention_weights = F.softmax(scores, dim=-1)
attention_weights = self.dropout(attention_weights)
output = torch.matmul(attention_weights, value)
return output, attention_weights
class FocalLoss(nn.Module):
def __init__(self, alpha=None, gamma=3.0, reduction='mean'):
super(FocalLoss, self).__init__()
self.alpha = alpha
self.gamma = gamma
self.reduction = reduction
def forward(self, inputs, targets):
ce_loss = F.cross_entropy(inputs, targets, reduction='none')
pt = torch.exp(-ce_loss)
if self.alpha is not None:
if self.alpha.type() != inputs.data.type():
self.alpha = self.alpha.type_as(inputs.data)
at = self.alpha.gather(0, targets.data.view(-1))
ce_loss = ce_loss * at
focal_weight = (1 - pt) ** self.gamma
focal_loss = focal_weight * ce_loss
if self.reduction == 'mean':
return focal_loss.mean()
elif self.reduction == 'sum':
return focal_loss.sum()
else:
return focal_loss
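# 双路径边界分类器(与"封装模型/Fastapi.py"中的定义一致):边界检测头输出的 boundary_score
# 经可学习的 boundary_force_weight 放大后,加到常规路径 logits 的"不同段落"维度上得到 final_logits。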
class DualPathBoundaryClassifier(nn.Module):
def __init__(self, model_path, num_labels=2, dropout=0.1,
focal_alpha=None, focal_gamma=3.0, boundary_force_weight=2.0):
super(DualPathBoundaryClassifier, self).__init__()
self.roberta = BertModel.from_pretrained(model_path)
self.config = self.roberta.config
self.config.num_labels = num_labels
self.scaled_attention = ScaledDotProductAttention(
d_model=self.config.hidden_size,
dropout=dropout
)
self.dropout = nn.Dropout(dropout)
# 双路径分类器
self.regular_classifier = nn.Linear(self.config.hidden_size, num_labels)
self.boundary_classifier = nn.Linear(self.config.hidden_size, num_labels)
self.boundary_detector = nn.Linear(self.config.hidden_size, 1)
# 边界强制权重
self.boundary_force_weight = nn.Parameter(torch.tensor(boundary_force_weight))
self.focal_loss = FocalLoss(alpha=focal_alpha, gamma=focal_gamma)
self._init_weights()
self.focal_alpha = focal_alpha
self.focal_gamma = focal_gamma
def _init_weights(self):
nn.init.normal_(self.regular_classifier.weight, std=0.02)
nn.init.zeros_(self.regular_classifier.bias)
nn.init.normal_(self.boundary_classifier.weight, std=0.02)
nn.init.zeros_(self.boundary_classifier.bias)
nn.init.normal_(self.boundary_detector.weight, std=0.02)
nn.init.zeros_(self.boundary_detector.bias)
def forward(self, input_ids, attention_mask=None, token_type_ids=None, labels=None):
roberta_outputs = self.roberta(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
return_dict=True
)
sequence_output = roberta_outputs.last_hidden_state
# 缩放点积注意力增强
enhanced_output, attention_weights = self.scaled_attention(
query=sequence_output,
key=sequence_output,
value=sequence_output,
mask=attention_mask.unsqueeze(1) if attention_mask is not None else None
)
cls_output = enhanced_output[:, 0, :]
cls_output = self.dropout(cls_output)
# 双路径分类
regular_logits = self.regular_classifier(cls_output)
boundary_logits = self.boundary_classifier(cls_output)
# 边界检测
boundary_logits_raw = self.boundary_detector(cls_output).squeeze(-1)
boundary_score = torch.sigmoid(boundary_logits_raw)
# 动态融合
boundary_bias = torch.zeros_like(regular_logits)
boundary_bias[:, 1] = boundary_score * self.boundary_force_weight
final_logits = regular_logits + boundary_bias
loss = None
if labels is not None:
regular_loss = self.focal_loss(regular_logits, labels)
boundary_loss = self.focal_loss(boundary_logits, labels)
final_loss = self.focal_loss(final_logits, labels)
boundary_labels = self._generate_boundary_labels(labels)
detection_loss = F.binary_cross_entropy_with_logits(boundary_logits_raw, boundary_labels)
total_loss = (0.4 * final_loss +
0.3 * regular_loss +
0.2 * boundary_loss +
0.1 * detection_loss)
loss = total_loss
return {
'loss': loss,
'logits': final_logits,
'regular_logits': regular_logits,
'boundary_logits': boundary_logits,
'boundary_score': boundary_score,
'hidden_states': enhanced_output,
'attention_weights': attention_weights
}
def _generate_boundary_labels(self, labels):
boundary_labels = labels.float()
noise = torch.rand_like(boundary_labels) * 0.1
boundary_labels = torch.clamp(boundary_labels + noise, 0.0, 1.0)
return boundary_labels
# ========== 模型加载部分 ==========
def check_gpu_availability():
if torch.cuda.is_available():
gpu_name = torch.cuda.get_device_name(0)
gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024 ** 3
print(f"🚀 GPU: {gpu_name} ({gpu_memory:.1f} GB)")
return torch.device('cuda')
else:
print("🔄 使用CPU")
return torch.device('cpu')
def safe_convert_focal_alpha(focal_alpha, device):
if focal_alpha is None:
return None
try:
if isinstance(focal_alpha, torch.Tensor):
return focal_alpha.to(device)
elif isinstance(focal_alpha, (list, tuple)):
return torch.tensor(focal_alpha, dtype=torch.float32).to(device)
elif isinstance(focal_alpha, np.ndarray):
return torch.from_numpy(focal_alpha).float().to(device)
elif isinstance(focal_alpha, (int, float)):
return torch.tensor([focal_alpha], dtype=torch.float32).to(device)
else:
return torch.tensor(focal_alpha, dtype=torch.float32).to(device)
except Exception as e:
print(f" 转换focal_alpha失败: {e}")
return None
def load_trained_dual_path_model(model_path, device, original_roberta_path=None):
try:
if original_roberta_path and os.path.exists(original_roberta_path):
tokenizer_path = original_roberta_path
else:
tokenizer_path = model_path
tokenizer = BertTokenizer.from_pretrained(tokenizer_path)
config_path = os.path.join(model_path, 'config.json')
if os.path.exists(config_path):
with open(config_path, 'r', encoding='utf-8') as f:
config = json.load(f)
focal_gamma = config.get('focal_gamma', 3.0)
focal_alpha_raw = config.get('focal_alpha', None)
boundary_force_weight = config.get('boundary_force_weight', 2.0)
else:
focal_gamma = 3.0
focal_alpha_raw = None
boundary_force_weight = 2.0
focal_alpha = safe_convert_focal_alpha(focal_alpha_raw, device)
if original_roberta_path and os.path.exists(original_roberta_path):
model_init_path = original_roberta_path
else:
model_init_path = tokenizer_path
model = DualPathBoundaryClassifier(
model_path=model_init_path,
num_labels=2,
dropout=0.1,
focal_alpha=focal_alpha,
focal_gamma=focal_gamma,
boundary_force_weight=boundary_force_weight
)
model_vocab_size = model.roberta.embeddings.word_embeddings.weight.shape[0]
if model_vocab_size != tokenizer.vocab_size:
raise ValueError(f"词汇表大小不匹配: 模型={model_vocab_size}, Tokenizer={tokenizer.vocab_size}")
model_weights_path = os.path.join(model_path, 'pytorch_model.bin')
if os.path.exists(model_weights_path):
try:
state_dict = torch.load(model_weights_path, map_location=device)
except Exception:
state_dict = torch.load(model_weights_path, map_location='cpu')
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
if missing_keys or unexpected_keys:
print(f" 加载权重时有差异: 缺少{len(missing_keys)}个键, 多余{len(unexpected_keys)}个键")
else:
raise FileNotFoundError(f"未找到模型权重文件: {model_weights_path}")
model.to(device)
model.eval()
print("✅ 双路径边界分类器模型加载完成")
return model, tokenizer
except Exception as e:
print(f"❌ 双路径模型加载失败: {str(e)}")
raise
# ========== 文本处理部分 ==========
def split_text_into_sentences(text: str) -> List[str]:
text = text.strip()
sentence_endings = r'([。!?;])'
parts = re.split(sentence_endings, text)
sentences = []
for i in range(0, len(parts), 2):
if i < len(parts):
sentence = parts[i].strip()
if sentence:
if i + 1 < len(parts):
sentence += parts[i + 1]
sentences.append(sentence)
return sentences
def predict_sentence_pairs_dual_path(sentences: List[str], model, tokenizer, device, max_length=384) -> Dict[str, str]:
if len(sentences) < 2:
return {"paragraph_1": sentences[0] if sentences else ""}
results = {}
current_paragraph_sentences = [sentences[0]]
with torch.no_grad():
for i in range(len(sentences) - 1):
sentence1_clean = re.sub(r'[。!?;]$', '', sentences[i])
sentence2_clean = re.sub(r'[。!?;]$', '', sentences[i + 1])
encoding = tokenizer(
sentence1_clean,
sentence2_clean,
truncation=True,
padding='max_length',
max_length=max_length,
return_tensors='pt'
)
input_ids = encoding['input_ids'].to(device)
attention_mask = encoding['attention_mask'].to(device)
outputs = model(
input_ids=input_ids,
attention_mask=attention_mask
)
logits = outputs['logits']
prediction = torch.argmax(logits, dim=-1).item()
if prediction == 0: # 同段落
current_paragraph_sentences.append(sentences[i + 1])
else: # 不同段落
paragraph_key = f"paragraph_{len(results) + 1}"
results[paragraph_key] = "".join(current_paragraph_sentences)
current_paragraph_sentences = [sentences[i + 1]]
if current_paragraph_sentences:
paragraph_key = f"paragraph_{len(results) + 1}"
results[paragraph_key] = "".join(current_paragraph_sentences)
return results
def process_single_broadcast_dual_path(text: str, broadcast_id: Optional[str] = None) -> dict:
try:
if not text or not text.strip():
return {
"broadcast_id": broadcast_id,
"segments": {},
"status": "failed",
"error": "文本为空"
}
sentences = split_text_into_sentences(text)
if len(sentences) == 0:
return {
"broadcast_id": broadcast_id,
"segments": {"paragraph_1": text},
"status": "success"
}
segments = predict_sentence_pairs_dual_path(sentences, model, tokenizer, device)
return {
"broadcast_id": broadcast_id,
"segments": segments,
"status": "success"
}
except Exception as e:
return {
"broadcast_id": broadcast_id,
"segments": {},
"status": "failed",
"error": str(e)
}
# ========== FastAPI应用部分 ==========
app = FastAPI(title="双路径边界分类器文本分段服务", version="3.0.0")
# 全局变量存储模型
model = None
tokenizer = None
device = None
# ========== 请求和响应模型 ==========
class TextInput(BaseModel):
广播内容: str
class BroadcastItem(BaseModel):
广播内容: str
广播ID: Optional[str] = None
BatchInput = List[BroadcastItem]
# ========== 生命周期事件 ==========
@app.on_event("startup")
async def load_model():
global model, tokenizer, device
model_path = "/work/model_robert/model_train-eval"
original_roberta_path = "/work/model_robert/model"
print("🚀 正在启动双路径边界分类器文本分段服务...")
try:
device = check_gpu_availability()
model, tokenizer = load_trained_dual_path_model(model_path, device, original_roberta_path)
print(f"✅ 双路径边界分类器模型加载成功! 设备: {device}")
# print(f"📝 词汇表大小: {tokenizer.vocab_size}")
# print(f"🎯 序列最大长度: 384 tokens")
except Exception as e:
print(f"❌ 双路径模型加载失败: {e}")
raise
# ========== API接口 ==========
@app.get("/")
async def root():
"""根路径 - 服务状态检查"""
return {
"service": "双路径边界分类器文本分段服务",
"status": "运行中" if model is not None else "模型未加载",
"version": "3.0.0",
"model_type": "DualPathBoundaryClassifier",
"features": [
"双路径架构: 常规分类器 + 边界分类器",
"边界检测器: 纯神经网络学习边界模式",
"动态权重融合: 自适应边界识别",
"序列长度: 384 tokens",
"单条处理",
"批量处理(简化)",
"详细分析"
]
}
@app.get("/health")
async def health_check():
"""健康检查接口"""
return {
"status": "healthy" if model is not None else "unhealthy",
"model_loaded": model is not None,
"model_type": "DualPathBoundaryClassifier",
"device": str(device) if device else None,
"max_sequence_length": 384,
"boundary_detection": "pure_neural_network"
}
@app.post("/segment_simple")
async def segment_text_simple(input_data: TextInput):
"""单条文本分段接口"""
if model is None or tokenizer is None:
return {"error": "双路径模型未加载"}
try:
result = process_single_broadcast_dual_path(input_data.广播内容)
if result["status"] == "success":
return result["segments"]
else:
return {"error": result["error"]}
except Exception as e:
return {"error": f"双路径处理失败: {str(e)}"}
@app.post("/segment_batch_simple")
async def segment_batch_simple(broadcasts: BatchInput):
"""批量文本分段接口 - 简化输出"""
if model is None or tokenizer is None:
return {"error": "双路径模型未加载"}
try:
if not broadcasts:
return {"error": "广播列表不能为空"}
start_time = time.time()
results = []
# 并行处理
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
future_to_broadcast = {
executor.submit(process_single_broadcast_dual_path,
broadcast.广播内容,
broadcast.广播ID): broadcast
for broadcast in broadcasts
}
for future in concurrent.futures.as_completed(future_to_broadcast):
try:
result = future.result()
results.append(result)
except Exception as e:
broadcast = future_to_broadcast[future]
results.append({
"broadcast_id": broadcast.广播ID,
"status": "failed",
"error": f"双路径处理异常: {str(e)}"
})
# 简化输出格式
simplified_results = {}
success_count = 0
for i, result in enumerate(results):
key = result.get("broadcast_id") or f"broadcast_{i + 1}"
if result["status"] == "success":
simplified_results[key] = result["segments"]
success_count += 1
else:
simplified_results[key] = {"error": result.get("error", "双路径处理失败")}
total_time = time.time() - start_time
return {
"model": "DualPathBoundaryClassifier",
"total": len(results),
"success": success_count,
"failed": len(results) - success_count,
"processing_time": round(total_time, 3),
"results": simplified_results
}
except Exception as e:
return {"error": f"双路径批量处理失败: {str(e)}"}
@app.post("/segment_with_details")
async def segment_with_details(input_data: TextInput):
"""带详细信息的文本分段接口"""
if model is None or tokenizer is None:
return {"error": "双路径模型未加载"}
try:
text = input_data.广播内容
sentences = split_text_into_sentences(text)
if len(sentences) < 2:
return {
"segments": {"paragraph_1": text},
"sentence_details": [],
"total_sentences": len(sentences)
}
results = {}
current_paragraph_sentences = [sentences[0]]
sentence_details = []
with torch.no_grad():
for i in range(len(sentences) - 1):
sentence1_clean = re.sub(r'[。!?;]$', '', sentences[i])
sentence2_clean = re.sub(r'[。!?;]$', '', sentences[i + 1])
encoding = tokenizer(
sentence1_clean,
sentence2_clean,
truncation=True,
padding='max_length',
max_length=384,
return_tensors='pt'
)
input_ids = encoding['input_ids'].to(device)
attention_mask = encoding['attention_mask'].to(device)
outputs = model(
input_ids=input_ids,
attention_mask=attention_mask
)
logits = outputs['logits']
regular_logits = outputs['regular_logits']
boundary_logits = outputs['boundary_logits']
boundary_score = outputs['boundary_score'].item()
prediction = torch.argmax(logits, dim=-1).item()
probability = torch.softmax(logits, dim=-1)
regular_prob = torch.softmax(regular_logits, dim=-1)
boundary_prob = torch.softmax(boundary_logits, dim=-1)
detail = {
"sentence_pair_index": i + 1,
"sentence1": sentences[i],
"sentence2": sentences[i + 1],
"prediction": prediction,
"prediction_label": "same_paragraph" if prediction == 0 else "different_paragraph",
"final_probabilities": {
"same_paragraph": float(probability[0][0]),
"different_paragraph": float(probability[0][1])
},
"regular_path_probabilities": {
"same_paragraph": float(regular_prob[0][0]),
"different_paragraph": float(regular_prob[0][1])
},
"boundary_path_probabilities": {
"same_paragraph": float(boundary_prob[0][0]),
"different_paragraph": float(boundary_prob[0][1])
},
"boundary_score": boundary_score,
"boundary_confidence": "high" if boundary_score > 0.7 else "medium" if boundary_score > 0.3 else "low"
}
sentence_details.append(detail)
if prediction == 0: # 同段落
current_paragraph_sentences.append(sentences[i + 1])
else: # 不同段落
paragraph_key = f"paragraph_{len(results) + 1}"
results[paragraph_key] = "".join(current_paragraph_sentences)
current_paragraph_sentences = [sentences[i + 1]]
if current_paragraph_sentences:
paragraph_key = f"paragraph_{len(results) + 1}"
results[paragraph_key] = "".join(current_paragraph_sentences)
boundary_scores = [d["boundary_score"] for d in sentence_details]
return {
"segments": results,
"sentence_details": sentence_details,
"total_sentences": len(sentences),
"total_pairs_analyzed": len(sentence_details),
"boundary_statistics": {
"average_boundary_score": round(sum(boundary_scores) / len(boundary_scores),
4) if boundary_scores else 0,
"max_boundary_score": round(max(boundary_scores), 4) if boundary_scores else 0,
"min_boundary_score": round(min(boundary_scores), 4) if boundary_scores else 0
}
}
except Exception as e:
return {"error": f"详细分析失败: {str(e)}"}
@app.get("/stats")
async def get_processing_stats():
"""获取处理统计信息"""
return {
"service_status": "running" if model is not None else "down",
"model_loaded": model is not None,
"model_type": "DualPathBoundaryClassifier",
"device": str(device) if device else None,
"vocab_size": tokenizer.vocab_size if tokenizer else None,
"max_sequence_length": 384,
"api_endpoints": [
"/segment_simple - 单条处理",
"/segment_batch_simple - 批量处理(简化)",
"/segment_with_details - 带详细信息分段"
]
}
# ========== 启动配置 ==========
if __name__ == "__main__":
uvicorn.run(app, host='0.0.0.0', port=8888)
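# Minimal usage sketch (sample text hypothetical; port 8888 as configured above):
#   resp = requests.post("http://localhost:8888/segment_simple",
#                        json={"广播内容": "各位听众早上好。下面播报天气。今天晴转多云。"})
# On success the endpoint returns the segments dict directly,
# e.g. {"paragraph_1": "...", "paragraph_2": "..."}; on failure it returns {"error": "..."}.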

1431
5.AI标注-model_trian/无标点符号训练/train-robert-wwm-ext.py

File diff suppressed because it is too large

67
6.model_train-test/API/API-test.py

@@ -0,0 +1,67 @@
import json
import requests
import os
# 配置文件路径
test_file_path = r"D:\workstation\chinese-roberta-wwm-ext\model-train-eval-NN\AI标注\test.json"
output_dir = r"D:\workstation\chinese-roberta-wwm-ext\model-train-eval-NN\AI标注\test01"
# 确保输出目录存在
os.makedirs(output_dir, exist_ok=True)
# 读取测试数据
with open(test_file_path, 'r', encoding='utf-8') as f:
test_data = json.load(f)
print(f"📁 加载测试数据: {len(test_data)} 条记录")
# 服务地址
url = "http://localhost:8888/segment_batch_simple"
# 准备请求数据 - 直接使用原始格式
broadcasts = test_data
print(f"🚀 开始调用双路径边界分类器批量分段接口...")
# 发送请求
try:
response = requests.post(url, json=broadcasts)
if response.status_code == 200:
result = response.json()
print("✅ 批量分段成功!")
print(f"模型: {result['model']}")
print(f"总计: {result['total']}")
print(f"成功: {result['success']}")
print(f"失败: {result['failed']}")
print(f"处理时间: {result['processing_time']}")
print("\n📝 分段结果:")
print("=" * 80)
for broadcast_id, segments in result['results'].items():
print(f"\n📻 {broadcast_id}:")
if 'error' in segments:
print(f"❌ 错误: {segments['error']}")
else:
for para_key, para_content in segments.items():
print(f" {para_key}: {para_content}")
# 保存结果到文件
output_file = os.path.join(output_dir, "batch_segment_results.json")
with open(output_file, 'w', encoding='utf-8') as f:
json.dump(result, f, ensure_ascii=False, indent=2)
print(f"\n💾 结果已保存到: {output_file}")
else:
print(f"❌ 请求失败: HTTP {response.status_code}")
print(response.text)
except requests.exceptions.ConnectionError:
print("❌ 连接失败,请确保双路径边界分类器服务正在运行")
print(" 启动命令: python simplified_dual_path_boundary_classifier_api.py")
print(" 服务地址: http://localhost:8888")
except Exception as e:
print(f"❌ 调用失败: {e}")

664
6.model_train-test/API/Fastapi.py

@@ -0,0 +1,664 @@
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import os
import re
import time
import concurrent.futures
from typing import Dict, List, Optional, Union
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from transformers import BertTokenizer, BertModel
import uvicorn
class ScaledDotProductAttention(nn.Module):
def __init__(self, d_model, dropout=0.1):
super(ScaledDotProductAttention, self).__init__()
self.d_model = d_model
self.dropout = nn.Dropout(dropout)
def forward(self, query, key, value, mask=None):
batch_size, seq_len, d_model = query.size()
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_model)
if mask is not None:
mask_value = torch.finfo(scores.dtype).min
scores = scores.masked_fill(mask == 0, mask_value)
attention_weights = F.softmax(scores, dim=-1)
attention_weights = self.dropout(attention_weights)
output = torch.matmul(attention_weights, value)
return output, attention_weights
class FocalLoss(nn.Module):
def __init__(self, alpha=None, gamma=3.0, reduction='mean'):
super(FocalLoss, self).__init__()
self.alpha = alpha
self.gamma = gamma
self.reduction = reduction
def forward(self, inputs, targets):
ce_loss = F.cross_entropy(inputs, targets, reduction='none')
pt = torch.exp(-ce_loss)
if self.alpha is not None:
if self.alpha.type() != inputs.data.type():
self.alpha = self.alpha.type_as(inputs.data)
at = self.alpha.gather(0, targets.data.view(-1))
ce_loss = ce_loss * at
focal_weight = (1 - pt) ** self.gamma
focal_loss = focal_weight * ce_loss
if self.reduction == 'mean':
return focal_loss.mean()
elif self.reduction == 'sum':
return focal_loss.sum()
else:
return focal_loss
class DualPathBoundaryClassifier(nn.Module):
def __init__(self, model_path, num_labels=2, dropout=0.1,
focal_alpha=None, focal_gamma=3.0, boundary_force_weight=2.0):
super(DualPathBoundaryClassifier, self).__init__()
self.roberta = BertModel.from_pretrained(model_path)
self.config = self.roberta.config
self.config.num_labels = num_labels
self.scaled_attention = ScaledDotProductAttention(
d_model=self.config.hidden_size,
dropout=dropout
)
self.dropout = nn.Dropout(dropout)
# 双路径分类器
self.regular_classifier = nn.Linear(self.config.hidden_size, num_labels)
self.boundary_classifier = nn.Linear(self.config.hidden_size, num_labels)
self.boundary_detector = nn.Linear(self.config.hidden_size, 1)
# 边界强制权重
self.boundary_force_weight = nn.Parameter(torch.tensor(boundary_force_weight))
self.focal_loss = FocalLoss(alpha=focal_alpha, gamma=focal_gamma)
self._init_weights()
self.focal_alpha = focal_alpha
self.focal_gamma = focal_gamma
def _init_weights(self):
nn.init.normal_(self.regular_classifier.weight, std=0.02)
nn.init.zeros_(self.regular_classifier.bias)
nn.init.normal_(self.boundary_classifier.weight, std=0.02)
nn.init.zeros_(self.boundary_classifier.bias)
nn.init.normal_(self.boundary_detector.weight, std=0.02)
nn.init.zeros_(self.boundary_detector.bias)
def forward(self, input_ids, attention_mask=None, token_type_ids=None, labels=None):
roberta_outputs = self.roberta(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
return_dict=True
)
sequence_output = roberta_outputs.last_hidden_state
# 缩放点积注意力增强
enhanced_output, attention_weights = self.scaled_attention(
query=sequence_output,
key=sequence_output,
value=sequence_output,
mask=attention_mask.unsqueeze(1) if attention_mask is not None else None
)
cls_output = enhanced_output[:, 0, :]
cls_output = self.dropout(cls_output)
# 双路径分类
regular_logits = self.regular_classifier(cls_output)
boundary_logits = self.boundary_classifier(cls_output)
# 边界检测
boundary_logits_raw = self.boundary_detector(cls_output).squeeze(-1)
boundary_score = torch.sigmoid(boundary_logits_raw)
# 动态融合
boundary_bias = torch.zeros_like(regular_logits)
boundary_bias[:, 1] = boundary_score * self.boundary_force_weight
final_logits = regular_logits + boundary_bias
loss = None
if labels is not None:
regular_loss = self.focal_loss(regular_logits, labels)
boundary_loss = self.focal_loss(boundary_logits, labels)
final_loss = self.focal_loss(final_logits, labels)
boundary_labels = self._generate_boundary_labels(labels)
detection_loss = F.binary_cross_entropy_with_logits(boundary_logits_raw, boundary_labels)
total_loss = (0.4 * final_loss +
0.3 * regular_loss +
0.2 * boundary_loss +
0.1 * detection_loss)
loss = total_loss
return {
'loss': loss,
'logits': final_logits,
'regular_logits': regular_logits,
'boundary_logits': boundary_logits,
'boundary_score': boundary_score,
'hidden_states': enhanced_output,
'attention_weights': attention_weights
}
def _generate_boundary_labels(self, labels):
boundary_labels = labels.float()
noise = torch.rand_like(boundary_labels) * 0.1
boundary_labels = torch.clamp(boundary_labels + noise, 0.0, 1.0)
return boundary_labels
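# Fusion recap: final_logits = regular_logits + [0, boundary_score * boundary_force_weight],
# so the sigmoid boundary score can only push probability mass toward class 1
# ("different paragraph"), never toward merging. When labels are provided, the
# detection head is supervised with the paragraph labels softened by up to 0.1
# uniform noise (see _generate_boundary_labels above).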
def check_gpu_availability():
if torch.cuda.is_available():
gpu_name = torch.cuda.get_device_name(0)
gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024 ** 3
print(f"🚀 GPU: {gpu_name} ({gpu_memory:.1f} GB)")
return torch.device('cuda')
else:
print("🔄 使用CPU")
return torch.device('cpu')
def safe_convert_focal_alpha(focal_alpha, device):
if focal_alpha is None:
return None
try:
if isinstance(focal_alpha, torch.Tensor):
return focal_alpha.to(device)
elif isinstance(focal_alpha, (list, tuple)):
return torch.tensor(focal_alpha, dtype=torch.float32).to(device)
elif isinstance(focal_alpha, np.ndarray):
return torch.from_numpy(focal_alpha).float().to(device)
elif isinstance(focal_alpha, (int, float)):
return torch.tensor([focal_alpha], dtype=torch.float32).to(device)
else:
return torch.tensor(focal_alpha, dtype=torch.float32).to(device)
except Exception as e:
print(f" 转换focal_alpha失败: {e}")
return None
def load_trained_dual_path_model(model_path, device, original_roberta_path=None):
try:
if original_roberta_path and os.path.exists(original_roberta_path):
tokenizer_path = original_roberta_path
else:
tokenizer_path = model_path
tokenizer = BertTokenizer.from_pretrained(tokenizer_path)
config_path = os.path.join(model_path, 'config.json')
if os.path.exists(config_path):
with open(config_path, 'r', encoding='utf-8') as f:
config = json.load(f)
focal_gamma = config.get('focal_gamma', 3.0)
focal_alpha_raw = config.get('focal_alpha', None)
boundary_force_weight = config.get('boundary_force_weight', 2.0)
else:
focal_gamma = 3.0
focal_alpha_raw = None
boundary_force_weight = 2.0
focal_alpha = safe_convert_focal_alpha(focal_alpha_raw, device)
if original_roberta_path and os.path.exists(original_roberta_path):
model_init_path = original_roberta_path
else:
model_init_path = tokenizer_path
model = DualPathBoundaryClassifier(
model_path=model_init_path,
num_labels=2,
dropout=0.1,
focal_alpha=focal_alpha,
focal_gamma=focal_gamma,
boundary_force_weight=boundary_force_weight
)
model_vocab_size = model.roberta.embeddings.word_embeddings.weight.shape[0]
if model_vocab_size != tokenizer.vocab_size:
raise ValueError(f"词汇表大小不匹配: 模型={model_vocab_size}, Tokenizer={tokenizer.vocab_size}")
model_weights_path = os.path.join(model_path, 'pytorch_model.bin')
if os.path.exists(model_weights_path):
try:
state_dict = torch.load(model_weights_path, map_location=device)
except Exception:
state_dict = torch.load(model_weights_path, map_location='cpu')
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
if missing_keys or unexpected_keys:
print(f" 加载权重时有差异: 缺少{len(missing_keys)}个键, 多余{len(unexpected_keys)}个键")
else:
raise FileNotFoundError(f"未找到模型权重文件: {model_weights_path}")
model.to(device)
model.eval()
print(" 双路径边界分类器模型加载完成")
return model, tokenizer
except Exception as e:
print(f"❌ 双路径模型加载失败: {str(e)}")
raise
def split_text_into_sentences(text: str) -> List[str]:
text = text.strip()
sentence_endings = r'([。!?;])'
parts = re.split(sentence_endings, text)
sentences = []
for i in range(0, len(parts), 2):
if i < len(parts):
sentence = parts[i].strip()
if sentence:
if i + 1 < len(parts):
sentence += parts[i + 1]
sentences.append(sentence)
return sentences
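# Worked example (for illustration): "今天晴。注意防暑!出行平安" splits on 。!?; into
# ["今天晴。", "注意防暑!", "出行平安"]; terminal punctuation stays attached to its
# sentence and a trailing fragment without punctuation is kept as-is.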
def predict_sentence_pairs_dual_path(sentences: List[str], model, tokenizer, device, max_length=384) -> Dict[str, str]:
if len(sentences) < 2:
return {"paragraph_1": sentences[0] if sentences else ""}
results = {}
current_paragraph_sentences = [sentences[0]]
with torch.no_grad():
for i in range(len(sentences) - 1):
# 修改:保留标点符号,与文件一保持一致
sentence1_clean = sentences[i] # 不再移除标点
sentence2_clean = sentences[i + 1] # 不再移除标点
encoding = tokenizer(
sentence1_clean,
sentence2_clean,
truncation=True,
padding='max_length',
max_length=max_length,
return_tensors='pt'
)
input_ids = encoding['input_ids'].to(device)
attention_mask = encoding['attention_mask'].to(device)
outputs = model(
input_ids=input_ids,
attention_mask=attention_mask
)
logits = outputs['logits']
prediction = torch.argmax(logits, dim=-1).item()
if prediction == 0: # 同段落
current_paragraph_sentences.append(sentences[i + 1])
else: # 不同段落
paragraph_key = f"paragraph_{len(results) + 1}"
results[paragraph_key] = "".join(current_paragraph_sentences)
current_paragraph_sentences = [sentences[i + 1]]
if current_paragraph_sentences:
paragraph_key = f"paragraph_{len(results) + 1}"
results[paragraph_key] = "".join(current_paragraph_sentences)
return results
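# Example (for illustration): sentences [s1, s2, s3, s4] with pairwise predictions
# [0, 1, 0] yield {"paragraph_1": s1 + s2, "paragraph_2": s3 + s4}.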
def process_single_broadcast_dual_path(text: str, broadcast_id: Optional[str] = None) -> dict:
try:
if not text or not text.strip():
return {
"broadcast_id": broadcast_id,
"segments": {},
"status": "failed",
"error": "文本为空"
}
sentences = split_text_into_sentences(text)
if len(sentences) == 0:
return {
"broadcast_id": broadcast_id,
"segments": {"paragraph_1": text},
"status": "success"
}
segments = predict_sentence_pairs_dual_path(sentences, model, tokenizer, device)
return {
"broadcast_id": broadcast_id,
"segments": segments,
"status": "success"
}
except Exception as e:
return {
"broadcast_id": broadcast_id,
"segments": {},
"status": "failed",
"error": str(e)
}
# ========== FastAPI应用部分 ==========
app = FastAPI(title="双路径边界分类器文本分段服务", version="3.1.0")
# 全局变量存储模型
model = None
tokenizer = None
device = None
# ========== 请求和响应模型 ==========
class TextInput(BaseModel):
广播内容: str
class BroadcastItem(BaseModel):
广播内容: str
广播ID: Optional[str] = None
BatchInput = List[BroadcastItem]
# ========== 生命周期事件 ==========
@app.on_event("startup")
async def load_model():
global model, tokenizer, device
model_path = r"D:\workstation\chinese-roberta-wwm-ext\model-train-eval-NN\model_train-NN"
original_roberta_path = r"D:\workstation\chinese-roberta-wwm-ext\model"
print("🚀 正在启动双路径边界分类器文本分段服务...")
try:
device = check_gpu_availability()
model, tokenizer = load_trained_dual_path_model(model_path, device, original_roberta_path)
# print(f"✅ 双路径边界分类器模型加载成功! 设备: {device}")
# print("📝 修改说明: 已保留标点符号处理,与测试脚本保持一致")
except Exception as e:
print(f"❌ 双路径模型加载失败: {e}")
raise
@app.get("/")
async def root():
return {
"service": "双路径边界分类器文本分段服务",
"status": "运行中" if model is not None else "模型未加载",
"version": "3.1.0",
"model_type": "DualPathBoundaryClassifier",
"updates": [
"v3.1.0: 修复标点符号处理不一致问题",
"保留句末标点符号,与测试脚本保持一致",
"提高分段准确性"
],
"features": [
"双路径架构: 常规分类器 + 边界分类器",
"边界检测器: 纯神经网络学习边界模式",
"动态权重融合: 自适应边界识别",
"序列长度: 384 tokens",
"标点符号保留: 保持训练-推理一致性",
"单条处理",
"批量处理(简化)",
"详细分析"
]
}
@app.get("/health")
async def health_check():
"""健康检查接口"""
return {
"status": "healthy" if model is not None else "unhealthy",
"model_loaded": model is not None,
"model_type": "DualPathBoundaryClassifier",
"device": str(device) if device else None,
"max_sequence_length": 384,
"boundary_detection": "pure_neural_network",
"punctuation_handling": "preserved",
"version": "3.1.0"
}
@app.post("/segment_simple")
async def segment_text_simple(input_data: TextInput):
if model is None or tokenizer is None:
return {"error": "双路径模型未加载"}
try:
result = process_single_broadcast_dual_path(input_data.广播内容)
if result["status"] == "success":
return result["segments"]
else:
return {"error": result["error"]}
except Exception as e:
return {"error": f"双路径处理失败: {str(e)}"}
@app.post("/segment_batch_simple")
async def segment_batch_simple(broadcasts: BatchInput):
if model is None or tokenizer is None:
return {"error": "双路径模型未加载"}
try:
if not broadcasts:
return {"error": "广播列表不能为空"}
start_time = time.time()
results = []
# 并行处理
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
future_to_broadcast = {
executor.submit(process_single_broadcast_dual_path,
broadcast.广播内容,
broadcast.广播ID): broadcast
for broadcast in broadcasts
}
for future in concurrent.futures.as_completed(future_to_broadcast):
try:
result = future.result()
results.append(result)
except Exception as e:
broadcast = future_to_broadcast[future]
results.append({
"broadcast_id": broadcast.广播ID,
"status": "failed",
"error": f"双路径处理异常: {str(e)}"
})
# 简化输出格式
simplified_results = {}
success_count = 0
for i, result in enumerate(results):
key = result.get("broadcast_id") or f"broadcast_{i + 1}"
if result["status"] == "success":
simplified_results[key] = result["segments"]
success_count += 1
else:
simplified_results[key] = {"error": result.get("error", "双路径处理失败")}
total_time = time.time() - start_time
return {
"model": "DualPathBoundaryClassifier",
"version": "3.1.0",
"total": len(results),
"success": success_count,
"failed": len(results) - success_count,
"processing_time": round(total_time, 3),
"results": simplified_results
}
except Exception as e:
return {"error": f"双路径批量处理失败: {str(e)}"}
@app.post("/segment_with_details")
async def segment_with_details(input_data: TextInput):
if model is None or tokenizer is None:
return {"error": "双路径模型未加载"}
try:
text = input_data.广播内容
sentences = split_text_into_sentences(text)
if len(sentences) < 2:
return {
"segments": {"paragraph_1": text},
"sentence_details": [],
"total_sentences": len(sentences),
"version": "3.1.0",
"punctuation_handling": "preserved"
}
results = {}
current_paragraph_sentences = [sentences[0]]
sentence_details = []
with torch.no_grad():
for i in range(len(sentences) - 1):
sentence1_clean = sentences[i] # 不再移除标点
sentence2_clean = sentences[i + 1] # 不再移除标点
encoding = tokenizer(
sentence1_clean,
sentence2_clean,
truncation=True,
padding='max_length',
max_length=384,
return_tensors='pt'
)
input_ids = encoding['input_ids'].to(device)
attention_mask = encoding['attention_mask'].to(device)
outputs = model(
input_ids=input_ids,
attention_mask=attention_mask
)
logits = outputs['logits']
regular_logits = outputs['regular_logits']
boundary_logits = outputs['boundary_logits']
boundary_score = outputs['boundary_score'].item()
prediction = torch.argmax(logits, dim=-1).item()
probability = torch.softmax(logits, dim=-1)
regular_prob = torch.softmax(regular_logits, dim=-1)
boundary_prob = torch.softmax(boundary_logits, dim=-1)
detail = {
"sentence_pair_index": i + 1,
"sentence1": sentences[i],
"sentence2": sentences[i + 1],
"sentence1_input": sentence1_clean, # 显示实际输入模型的文本
"sentence2_input": sentence2_clean, # 显示实际输入模型的文本
"prediction": prediction,
"prediction_label": "same_paragraph" if prediction == 0 else "different_paragraph",
"final_probabilities": {
"same_paragraph": float(probability[0][0]),
"different_paragraph": float(probability[0][1])
},
"regular_path_probabilities": {
"same_paragraph": float(regular_prob[0][0]),
"different_paragraph": float(regular_prob[0][1])
},
"boundary_path_probabilities": {
"same_paragraph": float(boundary_prob[0][0]),
"different_paragraph": float(boundary_prob[0][1])
},
"boundary_score": boundary_score,
"boundary_confidence": "high" if boundary_score > 0.7 else "medium" if boundary_score > 0.3 else "low"
}
sentence_details.append(detail)
if prediction == 0: # 同段落
current_paragraph_sentences.append(sentences[i + 1])
else: # 不同段落
paragraph_key = f"paragraph_{len(results) + 1}"
results[paragraph_key] = "".join(current_paragraph_sentences)
current_paragraph_sentences = [sentences[i + 1]]
if current_paragraph_sentences:
paragraph_key = f"paragraph_{len(results) + 1}"
results[paragraph_key] = "".join(current_paragraph_sentences)
boundary_scores = [d["boundary_score"] for d in sentence_details]
return {
"segments": results,
"sentence_details": sentence_details,
"total_sentences": len(sentences),
"total_pairs_analyzed": len(sentence_details),
"boundary_statistics": {
"average_boundary_score": round(sum(boundary_scores) / len(boundary_scores),
4) if boundary_scores else 0,
"max_boundary_score": round(max(boundary_scores), 4) if boundary_scores else 0,
"min_boundary_score": round(min(boundary_scores), 4) if boundary_scores else 0
},
"version": "3.1.0",
"punctuation_handling": "preserved",
"processing_note": "句末标点符号已保留,与训练数据保持一致"
}
except Exception as e:
return {"error": f"详细分析失败: {str(e)}"}
@app.get("/stats")
async def get_processing_stats():
"""获取处理统计信息"""
return {
"service_status": "running" if model is not None else "down",
"model_loaded": model is not None,
"model_type": "DualPathBoundaryClassifier",
"device": str(device) if device else None,
"vocab_size": tokenizer.vocab_size if tokenizer else None,
"max_sequence_length": 384,
"version": "3.1.0",
"punctuation_handling": "preserved",
"consistency": "aligned_with_training_data",
"api_endpoints": [
"/segment_simple - 单条处理",
"/segment_batch_simple - 批量处理(简化)",
"/segment_with_details - 带详细信息分段"
]
}
# ========== 启动配置 ==========
if __name__ == "__main__":
uvicorn.run(app, host='0.0.0.0', port=8888)
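# Batch response envelope sketch for /segment_batch_simple (values hypothetical):
# {
#   "model": "DualPathBoundaryClassifier", "version": "3.1.0",
#   "total": 2, "success": 2, "failed": 0, "processing_time": 0.812,
#   "results": {"b001": {"paragraph_1": "..."}, "broadcast_2": {"paragraph_1": "..."}}
# }
# Entries without 广播ID fall back to the positional key broadcast_<n>.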

247
6.model_train-test/API/test-合并.py

@@ -0,0 +1,247 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
根据source_id和label标签合并段落并输出txt文件
将label=0的连续句子合并,以label=1作为分界点分段
"""
import json
import os
from collections import defaultdict
from typing import List, Dict, Any
def load_test_data(file_path: str) -> List[Dict[str, Any]]:
"""加载测试数据"""
try:
with open(file_path, 'r', encoding='utf-8') as f:
data = json.load(f)
print(f"成功加载 {len(data)} 条数据")
return data
except FileNotFoundError:
print(f"错误:找不到文件 {file_path}")
return []
except json.JSONDecodeError as e:
print(f"错误:JSON格式错误 - {e}")
return []
except Exception as e:
print(f"错误:加载文件时出现问题 - {e}")
return []
def group_by_source_id(data: List[Dict[str, Any]]) -> Dict[str, List[Dict[str, Any]]]:
"""按source_id分组数据"""
grouped_data = defaultdict(list)
for item in data:
source_id = str(item.get('source_id', 'unknown'))
grouped_data[source_id].append(item)
# 对每个组内的数据按出现顺序排序(保持原有顺序)
for source_id in grouped_data:
# 如果数据中有索引信息,可以按索引排序
# 这里假设数据已经按正确顺序排列
pass
print(f"数据按source_id分组完成,共 {len(grouped_data)} 个组")
for source_id, items in grouped_data.items():
print(f" Source ID {source_id}: {len(items)} 条数据")
return dict(grouped_data)
def merge_paragraphs_by_labels(sentence_pairs: List[Dict[str, Any]]) -> List[str]:
"""
根据label合并段落
label=0: 同一段落需要合并
label=1: 不同段落作为分界点
"""
if not sentence_pairs:
return []
paragraphs = []
current_paragraph = []
# 处理第一个句子
if sentence_pairs:
current_paragraph.append(sentence_pairs[0]['sentence1'])
# 遍历所有句子对
for i, pair in enumerate(sentence_pairs):
sentence2 = pair['sentence2']
label = pair['label']
if label == 0:
# 同一段落,继续添加到当前段落
# 只添加sentence2,因为sentence1已经在上一轮添加过了
if sentence2 not in current_paragraph: # 避免重复
current_paragraph.append(sentence2)
elif label == 1:
# 不同段落,结束当前段落,开始新段落
if current_paragraph:
paragraph_text = ''.join(current_paragraph)
if paragraph_text.strip(): # 确保段落不为空
paragraphs.append(paragraph_text.strip())
# 开始新段落
current_paragraph = [sentence2]
# 处理最后一个段落
if current_paragraph:
paragraph_text = ''.join(current_paragraph)
if paragraph_text.strip():
paragraphs.append(paragraph_text.strip())
return paragraphs
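# Example (for illustration): pairs (s1,s2,label 0), (s2,s3,label 1), (s3,s4,label 0)
# produce two paragraphs, s1+s2 and s3+s4: label 1 closes the current paragraph and
# that pair's sentence2 opens the next one.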
def process_single_source(source_id: str, sentence_pairs: List[Dict[str, Any]]) -> Dict[str, Any]:
"""处理单个source_id的数据"""
print(f"\n处理Source ID: {source_id}")
print(f"句子对数量: {len(sentence_pairs)}")
# 统计标签分布
label_counts = defaultdict(int)
for pair in sentence_pairs:
label_counts[pair['label']] += 1
print(f"标签分布: Label 0 (同段): {label_counts[0]}, Label 1 (分段): {label_counts[1]}")
# 合并段落
paragraphs = merge_paragraphs_by_labels(sentence_pairs)
print(f"合并后段落数: {len(paragraphs)}")
# 统计信息
total_chars = sum(len(p) for p in paragraphs)
avg_paragraph_length = total_chars / len(paragraphs) if paragraphs else 0
return {
'source_id': source_id,
'original_pairs_count': len(sentence_pairs),
'merged_paragraphs_count': len(paragraphs),
'label_distribution': dict(label_counts),
'total_characters': total_chars,
'avg_paragraph_length': avg_paragraph_length,
'paragraphs': paragraphs
}
def save_to_txt(results: Dict[str, Dict[str, Any]], output_file: str):
"""保存结果到txt文件"""
with open(output_file, 'w', encoding='utf-8') as f:
f.write("=" * 80 + "\n")
f.write("段落合并结果\n")
f.write("根据source_id和label标签合并的段落文本\n")
f.write("=" * 80 + "\n\n")
for source_id, result in results.items():
f.write(f"【Source ID: {source_id}\n")
f.write(f"原始句子对数量: {result['original_pairs_count']}\n")
f.write(f"合并后段落数量: {result['merged_paragraphs_count']}\n")
f.write(f"标签分布: {result['label_distribution']}\n")
f.write(f"总字符数: {result['total_characters']}\n")
f.write(f"平均段落长度: {result['avg_paragraph_length']:.1f} 字符\n")
f.write("-" * 60 + "\n")
for i, paragraph in enumerate(result['paragraphs'], 1):
f.write(f"段落 {i}:\n{paragraph}\n\n")
f.write("=" * 80 + "\n\n")
def save_summary_json(results: Dict[str, Dict[str, Any]], output_file: str):
"""保存统计摘要到JSON文件"""
summary = {
'total_source_ids': len(results),
'total_original_pairs': sum(r['original_pairs_count'] for r in results.values()),
'total_merged_paragraphs': sum(r['merged_paragraphs_count'] for r in results.values()),
'total_characters': sum(r['total_characters'] for r in results.values()),
'source_details': {}
}
for source_id, result in results.items():
summary['source_details'][source_id] = {
'original_pairs_count': result['original_pairs_count'],
'merged_paragraphs_count': result['merged_paragraphs_count'],
'label_distribution': result['label_distribution'],
'total_characters': result['total_characters'],
'avg_paragraph_length': result['avg_paragraph_length']
}
with open(output_file, 'w', encoding='utf-8') as f:
json.dump(summary, f, ensure_ascii=False, indent=2)
def main():
"""主函数"""
# 配置文件路径
input_file = 'segmentation_results_from_new50.json' # 输入文件路径
output_txt = 'merged_paragraphs.txt' # 输出txt文件
output_summary = 'merge_summary.json' # 输出统计摘要
print("=" * 80)
print("段落合并处理程序")
print("根据source_id和label标签合并段落")
print("=" * 80)
# 检查输入文件
if not os.path.exists(input_file):
print(f"错误:输入文件 {input_file} 不存在!")
print("请确保test.json文件在当前目录下")
return
try:
# 1. 加载数据
data = load_test_data(input_file)
if not data:
print("没有有效数据可处理")
return
# 2. 按source_id分组
grouped_data = group_by_source_id(data)
# 3. 处理每个source_id的数据
results = {}
total_paragraphs = 0
for source_id, sentence_pairs in grouped_data.items():
result = process_single_source(source_id, sentence_pairs)
results[source_id] = result
total_paragraphs += result['merged_paragraphs_count']
# 4. 保存结果
print(f"\n保存结果...")
save_to_txt(results, output_txt)
save_summary_json(results, output_summary)
# 5. 输出总结
print("=" * 80)
print("处理完成!")
print("=" * 80)
print(f"📊 处理统计:")
print(f" 🔹 处理的Source ID数量: {len(results)}")
print(f" 🔹 原始句子对总数: {sum(r['original_pairs_count'] for r in results.values())}")
print(f" 🔹 合并后段落总数: {total_paragraphs}")
print(f" 🔹 总字符数: {sum(r['total_characters'] for r in results.values())}")
print(f"\n📁 输出文件:")
print(f" 📄 {output_txt} - 合并后的段落文本")
print(f" 📄 {output_summary} - 处理统计摘要")
print(f"\n📋 各Source ID详情:")
for source_id, result in results.items():
print(
f" Source {source_id}: {result['original_pairs_count']} 对 → {result['merged_paragraphs_count']}")
print(f"\n✅ 段落合并完成!请查看 {output_txt} 文件")
except Exception as e:
print(f"❌ 处理过程中出现错误: {str(e)}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
main()
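# Input format assumed by this script: a JSON array of sentence-pair records, each
# with "source_id", "sentence1", "sentence2" and a binary "label"
# (0 = same paragraph, 1 = paragraph boundary), e.g. the
# segmentation_results_from_new50.json configured in main().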

10
6.model_train-test/API/test.json

File diff suppressed because one or more lines are too long

658
6.model_train-test/Project.py

@@ -0,0 +1,658 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, precision_recall_fscore_support
from transformers import BertTokenizer, BertModel
from torch.utils.data import Dataset, DataLoader
import logging
import os
import json
from datetime import datetime
import math
# 设置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# 设置matplotlib英文字体
plt.rcParams['font.sans-serif'] = ['Arial', 'DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False
def check_gpu_availability():
"""检查GPU可用性"""
if torch.cuda.is_available():
gpu_count = torch.cuda.device_count()
gpu_name = torch.cuda.get_device_name(0)
gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024 ** 3
logger.info(f"✅ GPU Available!")
logger.info(f" 🔹 GPU Count: {gpu_count}")
logger.info(f" 🔹 GPU Model: {gpu_name}")
logger.info(f" 🔹 GPU Memory: {gpu_memory:.1f} GB")
torch.cuda.empty_cache()
torch.backends.cudnn.benchmark = True
return True, gpu_memory
else:
logger.info(" GPU not available, using CPU")
return False, 0
class FocalLoss(nn.Module):
"""Focal Loss for handling class imbalance"""
def __init__(self, alpha=None, gamma=3.0, reduction='mean'):
super(FocalLoss, self).__init__()
self.alpha = alpha
self.gamma = gamma
self.reduction = reduction
def forward(self, inputs, targets):
ce_loss = F.cross_entropy(inputs, targets, reduction='none')
pt = torch.exp(-ce_loss)
if self.alpha is not None:
if self.alpha.type() != inputs.data.type():
self.alpha = self.alpha.type_as(inputs.data)
at = self.alpha.gather(0, targets.data.view(-1))
ce_loss = ce_loss * at
focal_weight = (1 - pt) ** self.gamma
focal_loss = focal_weight * ce_loss
if self.reduction == 'mean':
return focal_loss.mean()
elif self.reduction == 'sum':
return focal_loss.sum()
else:
return focal_loss
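# Focal loss recap: FL = alpha_t * (1 - p_t)^gamma * CE, with gamma = 3.0 by default,
# so well-classified pairs (p_t close to 1) contribute almost nothing and training
# concentrates on hard and minority-class examples.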
class ScaledDotProductAttention(nn.Module):
"""Scaled Dot Product Attention"""
def __init__(self, d_model, dropout=0.1):
super(ScaledDotProductAttention, self).__init__()
self.d_model = d_model
self.dropout = nn.Dropout(dropout)
def forward(self, query, key, value, mask=None):
batch_size, seq_len, d_model = query.size()
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_model)
if mask is not None:
mask_value = torch.finfo(scores.dtype).min
scores = scores.masked_fill(mask == 0, mask_value)
attention_weights = F.softmax(scores, dim=-1)
attention_weights = self.dropout(attention_weights)
output = torch.matmul(attention_weights, value)
return output, attention_weights
class DualPathBoundaryClassifier(nn.Module):
"""Dual Path Boundary Classifier with Pure Neural Network Learning"""
def __init__(self, model_path, num_labels=2, dropout=0.1,
focal_alpha=None, focal_gamma=3.0, boundary_force_weight=2.0):
super(DualPathBoundaryClassifier, self).__init__()
self.roberta = BertModel.from_pretrained(model_path)
self.config = self.roberta.config
self.config.num_labels = num_labels
self.scaled_attention = ScaledDotProductAttention(
d_model=self.config.hidden_size,
dropout=dropout
)
self.dropout = nn.Dropout(dropout)
# Dual path classifiers
self.regular_classifier = nn.Linear(self.config.hidden_size, num_labels)
self.boundary_classifier = nn.Linear(self.config.hidden_size, num_labels)
self.boundary_detector = nn.Linear(self.config.hidden_size, 1)
# Boundary force weight
self.boundary_force_weight = nn.Parameter(torch.tensor(boundary_force_weight))
self.focal_loss = FocalLoss(alpha=focal_alpha, gamma=focal_gamma)
self._init_weights()
self.focal_alpha = focal_alpha
self.focal_gamma = focal_gamma
def _init_weights(self):
"""Initialize weights for new layers"""
nn.init.normal_(self.regular_classifier.weight, std=0.02)
nn.init.zeros_(self.regular_classifier.bias)
nn.init.normal_(self.boundary_classifier.weight, std=0.02)
nn.init.zeros_(self.boundary_classifier.bias)
nn.init.normal_(self.boundary_detector.weight, std=0.02)
nn.init.zeros_(self.boundary_detector.bias)
def forward(self, input_ids, attention_mask=None, token_type_ids=None, labels=None):
roberta_outputs = self.roberta(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
return_dict=True
)
sequence_output = roberta_outputs.last_hidden_state
# Enhanced with scaled attention
enhanced_output, attention_weights = self.scaled_attention(
query=sequence_output,
key=sequence_output,
value=sequence_output,
mask=attention_mask.unsqueeze(1) if attention_mask is not None else None
)
cls_output = enhanced_output[:, 0, :]
cls_output = self.dropout(cls_output)
# Dual path classification
regular_logits = self.regular_classifier(cls_output)
boundary_logits = self.boundary_classifier(cls_output)
# Boundary detection
boundary_logits_raw = self.boundary_detector(cls_output).squeeze(-1)
boundary_score = torch.sigmoid(boundary_logits_raw)
# Dynamic fusion
boundary_bias = torch.zeros_like(regular_logits)
boundary_bias[:, 1] = boundary_score * self.boundary_force_weight
final_logits = regular_logits + boundary_bias
loss = None
if labels is not None:
regular_loss = self.focal_loss(regular_logits, labels)
boundary_loss = self.focal_loss(boundary_logits, labels)
final_loss = self.focal_loss(final_logits, labels)
boundary_labels = self._generate_boundary_labels(labels)
detection_loss = F.binary_cross_entropy_with_logits(boundary_logits_raw, boundary_labels)
total_loss = (0.4 * final_loss +
0.3 * regular_loss +
0.2 * boundary_loss +
0.1 * detection_loss)
loss = total_loss
return {
'loss': loss,
'logits': final_logits,
'regular_logits': regular_logits,
'boundary_logits': boundary_logits,
'boundary_score': boundary_score,
'hidden_states': enhanced_output,
'attention_weights': attention_weights
}
def _generate_boundary_labels(self, labels):
"""Generate heuristic labels for boundary detection"""
boundary_labels = labels.float()
noise = torch.rand_like(boundary_labels) * 0.1
boundary_labels = torch.clamp(boundary_labels + noise, 0.0, 1.0)
return boundary_labels
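# Training objective recap: total_loss = 0.4*final + 0.3*regular + 0.2*boundary
# + 0.1*detection, where the detection target is the paragraph label softened with
# up to 0.1 uniform noise (a heuristic soft target, see _generate_boundary_labels above).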
class SentencePairTestDataset(Dataset):
"""Sentence pair test dataset"""
def __init__(self, data, tokenizer, max_length=384):
self.data = data
self.tokenizer = tokenizer
self.max_length = max_length
# Filter valid data
self.valid_data = [item for item in data if item['label'] in [0, 1]]
logger.info(f"Original test data: {len(data)} items, Valid data: {len(self.valid_data)} items")
self.sentence1_list = [item['sentence1'] for item in self.valid_data]
self.sentence2_list = [item['sentence2'] for item in self.valid_data]
self.labels = [item['label'] for item in self.valid_data]
def __len__(self):
return len(self.valid_data)
def __getitem__(self, idx):
sentence1 = str(self.sentence1_list[idx])
sentence2 = str(self.sentence2_list[idx])
label = self.labels[idx]
encoding = self.tokenizer(
sentence1,
sentence2,
truncation=True,
padding='max_length',
max_length=self.max_length,
return_tensors='pt'
)
return {
'input_ids': encoding['input_ids'].flatten(),
'attention_mask': encoding['attention_mask'].flatten(),
'labels': torch.tensor(label, dtype=torch.long),
'sentence1': sentence1,
'sentence2': sentence2
}
def load_trained_model(model_path, tokenizer_path):
"""Load the trained dual path model"""
logger.info(f"Loading trained model from: {model_path}")
# Load tokenizer
tokenizer = BertTokenizer.from_pretrained(tokenizer_path)
# Load model configuration
config_path = os.path.join(model_path, 'config.json')
if os.path.exists(config_path):
with open(config_path, 'r', encoding='utf-8') as f:
model_config = json.load(f)
logger.info(f"Model config loaded: {model_config.get('model_type', 'Unknown')}")
# Initialize model
model = DualPathBoundaryClassifier(
model_path=tokenizer_path, # Use original model path for RoBERTa base
num_labels=2,
dropout=0.1,
focal_alpha=None,
focal_gamma=3.0,
boundary_force_weight=2.0
)
# Load trained weights
model_weights_path = os.path.join(model_path, 'pytorch_model.bin')
if os.path.exists(model_weights_path):
model.load_state_dict(torch.load(model_weights_path, map_location='cpu'))
logger.info("✅ Trained model weights loaded successfully")
else:
raise FileNotFoundError(f"Model weights not found at: {model_weights_path}")
return model, tokenizer
def evaluate_model(model, test_dataloader, device, output_dir):
"""Evaluate the model on test dataset"""
model.eval()
all_predictions = []
all_labels = []
all_probabilities = []
all_boundary_scores = []
logger.info("🔍 Starting model evaluation...")
with torch.no_grad():
for batch_idx, batch in enumerate(test_dataloader):
input_ids = batch['input_ids'].to(device)
attention_mask = batch['attention_mask'].to(device)
labels = batch['labels'].to(device)
outputs = model(
input_ids=input_ids,
attention_mask=attention_mask
)
logits = outputs['logits']
boundary_scores = outputs['boundary_score']
# Get predictions
probabilities = torch.softmax(logits, dim=-1)
predictions = torch.argmax(logits, dim=-1)
all_predictions.extend(predictions.cpu().numpy())
all_labels.extend(labels.cpu().numpy())
all_probabilities.extend(probabilities.cpu().numpy())
all_boundary_scores.extend(boundary_scores.cpu().numpy())
if (batch_idx + 1) % 100 == 0:
logger.info(f" Processed {batch_idx + 1} batches...")
# Convert to numpy arrays
all_predictions = np.array(all_predictions)
all_labels = np.array(all_labels)
all_probabilities = np.array(all_probabilities)
all_boundary_scores = np.array(all_boundary_scores)
# Calculate metrics
accuracy = accuracy_score(all_labels, all_predictions)
precision, recall, f1, support = precision_recall_fscore_support(all_labels, all_predictions, average=None)
# Generate detailed classification report
class_report = classification_report(all_labels, all_predictions,
target_names=['Same Paragraph (0)', 'Different Paragraph (1)'],
output_dict=True)
# Generate confusion matrix
cm = confusion_matrix(all_labels, all_predictions)
logger.info("📊 Test Results:")
logger.info(f" Overall Accuracy: {accuracy:.4f}")
logger.info(
f" Class 0 (Same Paragraph) - Precision: {precision[0]:.4f}, Recall: {recall[0]:.4f}, F1: {f1[0]:.4f}")
logger.info(
f" Class 1 (Different Paragraph) - Precision: {precision[1]:.4f}, Recall: {recall[1]:.4f}, F1: {f1[1]:.4f}")
# Save results
results = {
'overall_accuracy': float(accuracy),
'class_metrics': {
'class_0_same_paragraph': {
'precision': float(precision[0]),
'recall': float(recall[0]),
'f1_score': float(f1[0]),
'support': int(support[0])
},
'class_1_different_paragraph': {
'precision': float(precision[1]),
'recall': float(recall[1]),
'f1_score': float(f1[1]),
'support': int(support[1])
}
},
'confusion_matrix': cm.tolist(),
'classification_report': class_report,
'test_samples_count': len(all_predictions),
'boundary_score_stats': {
'mean': float(np.mean(all_boundary_scores)),
'std': float(np.std(all_boundary_scores)),
'min': float(np.min(all_boundary_scores)),
'max': float(np.max(all_boundary_scores))
}
}
return results, cm, all_predictions, all_labels, all_probabilities, all_boundary_scores
def plot_confusion_matrix(cm, output_dir, model_name="Dual Path Boundary Classifier"):
"""Plot and save confusion matrix"""
plt.figure(figsize=(10, 8))
# Create heatmap
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
xticklabels=['Same Paragraph (0)', 'Different Paragraph (1)'],
yticklabels=['Same Paragraph (0)', 'Different Paragraph (1)'],
cbar_kws={'label': 'Number of Samples'})
plt.title(f'Confusion Matrix - {model_name}\nTest Dataset Evaluation',
fontsize=16, fontweight='bold', pad=20)
plt.xlabel('Predicted Label', fontsize=14, fontweight='bold')
plt.ylabel('True Label', fontsize=14, fontweight='bold')
# Add accuracy text
accuracy = np.trace(cm) / np.sum(cm)
plt.text(0.5, -0.15, f'Overall Accuracy: {accuracy:.4f}',
ha='center', transform=plt.gca().transAxes, fontsize=12,
bbox=dict(boxstyle="round,pad=0.3", facecolor="lightblue", alpha=0.7))
# Add sample counts
total_samples = np.sum(cm)
plt.text(0.5, -0.25, f'Total Test Samples: {total_samples}',
ha='center', transform=plt.gca().transAxes, fontsize=10)
plt.tight_layout()
# Save plot
save_path = os.path.join(output_dir, 'confusion_matrix_test_results.png')
plt.savefig(save_path, dpi=300, bbox_inches='tight')
plt.close()
logger.info(f"📊 Confusion matrix saved: {save_path}")
return save_path
def plot_class_distribution(results, output_dir):
"""Plot class distribution and performance metrics"""
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 12))
# Class distribution
class_0_count = results['class_metrics']['class_0_same_paragraph']['support']
class_1_count = results['class_metrics']['class_1_different_paragraph']['support']
ax1.bar(['Same Paragraph (0)', 'Different Paragraph (1)'],
[class_0_count, class_1_count],
color=['skyblue', 'lightcoral'])
ax1.set_title('Test Dataset Class Distribution', fontweight='bold')
ax1.set_ylabel('Number of Samples')
# Add count labels on bars
ax1.text(0, class_0_count + max(class_0_count, class_1_count) * 0.01, str(class_0_count),
ha='center', fontweight='bold')
ax1.text(1, class_1_count + max(class_0_count, class_1_count) * 0.01, str(class_1_count),
ha='center', fontweight='bold')
# Precision comparison
precision_0 = results['class_metrics']['class_0_same_paragraph']['precision']
precision_1 = results['class_metrics']['class_1_different_paragraph']['precision']
ax2.bar(['Same Paragraph (0)', 'Different Paragraph (1)'],
[precision_0, precision_1],
color=['lightgreen', 'orange'])
ax2.set_title('Precision by Class', fontweight='bold')
ax2.set_ylabel('Precision Score')
ax2.set_ylim(0, 1)
# Recall comparison
recall_0 = results['class_metrics']['class_0_same_paragraph']['recall']
recall_1 = results['class_metrics']['class_1_different_paragraph']['recall']
ax3.bar(['Same Paragraph (0)', 'Different Paragraph (1)'],
[recall_0, recall_1],
color=['lightcyan', 'plum'])
ax3.set_title('Recall by Class', fontweight='bold')
ax3.set_ylabel('Recall Score')
ax3.set_ylim(0, 1)
# F1-Score comparison
f1_0 = results['class_metrics']['class_0_same_paragraph']['f1_score']
f1_1 = results['class_metrics']['class_1_different_paragraph']['f1_score']
ax4.bar(['Same Paragraph (0)', 'Different Paragraph (1)'],
[f1_0, f1_1],
color=['gold', 'mediumpurple'])
ax4.set_title('F1-Score by Class', fontweight='bold')
ax4.set_ylabel('F1-Score')
ax4.set_ylim(0, 1)
plt.suptitle('Dual Path Boundary Classifier - Test Performance Analysis',
fontsize=16, fontweight='bold')
plt.tight_layout()
# Save plot
save_path = os.path.join(output_dir, 'class_performance_analysis.png')
plt.savefig(save_path, dpi=300, bbox_inches='tight')
plt.close()
logger.info(f"📊 Class performance analysis saved: {save_path}")
return save_path
def main():
"""Main function"""
logger.info("=" * 80)
logger.info("🚀 Dual Path Boundary Classifier - Test Evaluation")
logger.info("=" * 80)
# Configuration
original_model_path = r"D:\workstation\chinese-roberta-wwm-ext\model"
trained_model_path = r"D:\workstation\chinese-roberta-wwm-ext\model-train-eval-NN\model_train"
test_file = r"D:\workstation\AI标注\数据清洗+json\test_dataset.json"
output_dir = r"D:\workstation\AI标注\test"
# Ensure output directory exists
os.makedirs(output_dir, exist_ok=True)
# Check GPU availability
gpu_available, gpu_memory = check_gpu_availability()
device = torch.device('cuda' if gpu_available else 'cpu')
logger.info(f"📋 Test Configuration:")
logger.info(f" 🔹 Original Model: {original_model_path}")
logger.info(f" 🔹 Trained Model: {trained_model_path}")
logger.info(f" 🔹 Test Dataset: {test_file}")
logger.info(f" 🔹 Output Directory: {output_dir}")
logger.info(f" 🔹 Max Length: 384 tokens")
logger.info(f" 🔹 Device: {device}")
try:
# Load test data
logger.info("📥 Loading test dataset...")
with open(test_file, 'r', encoding='utf-8') as f:
test_data = json.load(f)
logger.info(f" Loaded {len(test_data)} test samples")
# Load trained model
model, tokenizer = load_trained_model(trained_model_path, original_model_path)
model = model.to(device)
# Create test dataset
test_dataset = SentencePairTestDataset(test_data, tokenizer, max_length=384)
test_dataloader = DataLoader(
test_dataset,
batch_size=32, # Optimized for RTX 4060
shuffle=False,
num_workers=4,
pin_memory=True if gpu_available else False
)
logger.info(f" Test dataset size: {len(test_dataset)}")
logger.info(f" Batch size: 32")
# Evaluate model
start_time = datetime.now()
results, cm, predictions, labels, probabilities, boundary_scores = evaluate_model(
model, test_dataloader, device, output_dir
)
end_time = datetime.now()
evaluation_time = end_time - start_time
logger.info(f" Evaluation completed in: {evaluation_time}")
# Generate visualizations
logger.info("📊 Generating visualizations...")
# Plot confusion matrix
cm_path = plot_confusion_matrix(cm, output_dir)
# Plot class performance analysis
perf_path = plot_class_distribution(results, output_dir)
# Save detailed results
results['evaluation_info'] = {
'evaluation_time': str(evaluation_time),
'device_used': str(device),
'model_type': 'DualPathBoundaryClassifier',
'max_length': 384,
'batch_size': 32,
'test_file': test_file,
'trained_model_path': trained_model_path
}
results_path = os.path.join(output_dir, 'test_results_detailed.json')
with open(results_path, 'w', encoding='utf-8') as f:
json.dump(results, f, ensure_ascii=False, indent=2)
# Save predictions
predictions_data = []
for i in range(len(predictions)):
predictions_data.append({
'index': i,
'sentence1': test_dataset[i]['sentence1'],
'sentence2': test_dataset[i]['sentence2'],
'true_label': int(labels[i]),
'predicted_label': int(predictions[i]),
'probability_class_0': float(probabilities[i][0]),
'probability_class_1': float(probabilities[i][1]),
'boundary_score': float(boundary_scores[i]),
'correct': bool(labels[i] == predictions[i])
})
predictions_path = os.path.join(output_dir, 'detailed_predictions.json')
with open(predictions_path, 'w', encoding='utf-8') as f:
json.dump(predictions_data, f, ensure_ascii=False, indent=2)
# Generate summary report
summary = {
'model_info': {
'model_type': 'Dual Path Boundary Classifier',
'base_model': 'Chinese-RoBERTa-WWM-Ext',
'max_length': 384,
'trained_model_path': trained_model_path
},
'test_results': {
'overall_accuracy': results['overall_accuracy'],
'total_samples': len(predictions),
'correct_predictions': int(np.sum(labels == predictions)),
'incorrect_predictions': int(np.sum(labels != predictions))
},
'class_performance': results['class_metrics'],
'boundary_detection': results['boundary_score_stats'],
'files_generated': [
'test_results_detailed.json',
'detailed_predictions.json',
'confusion_matrix_test_results.png',
'class_performance_analysis.png',
'test_summary.json'
]
}
summary_path = os.path.join(output_dir, 'test_summary.json')
with open(summary_path, 'w', encoding='utf-8') as f:
json.dump(summary, f, ensure_ascii=False, indent=2)
# Print final results
logger.info("=" * 80)
logger.info("🎉 Test Evaluation Completed!")
logger.info("=" * 80)
logger.info(f"📊 Final Results:")
logger.info(f" 🔹 Overall Accuracy: {results['overall_accuracy']:.4f}")
logger.info(f" 🔹 Total Test Samples: {len(predictions)}")
logger.info(f" 🔹 Correct Predictions: {np.sum(labels == predictions)}")
logger.info(f" 🔹 Evaluation Time: {evaluation_time}")
logger.info(f"\n📈 Class Performance:")
logger.info(f" Class 0 (Same Paragraph):")
logger.info(f" • Precision: {results['class_metrics']['class_0_same_paragraph']['precision']:.4f}")
logger.info(f" • Recall: {results['class_metrics']['class_0_same_paragraph']['recall']:.4f}")
logger.info(f" • F1-Score: {results['class_metrics']['class_0_same_paragraph']['f1_score']:.4f}")
logger.info(f" Class 1 (Different Paragraph):")
logger.info(f" • Precision: {results['class_metrics']['class_1_different_paragraph']['precision']:.4f}")
logger.info(f" • Recall: {results['class_metrics']['class_1_different_paragraph']['recall']:.4f}")
logger.info(f" • F1-Score: {results['class_metrics']['class_1_different_paragraph']['f1_score']:.4f}")
logger.info(f"\n📁 Generated Files in {output_dir}:")
logger.info(f" 📄 test_summary.json - Test evaluation summary")
logger.info(f" 📄 test_results_detailed.json - Detailed test results")
logger.info(f" 📄 detailed_predictions.json - Individual predictions")
logger.info(f" 📊 confusion_matrix_test_results.png - Confusion matrix")
logger.info(f" 📊 class_performance_analysis.png - Performance analysis")
logger.info(f"\n🎯 Model Performance Summary:")
logger.info(f" ✅ Dual Path Boundary Classifier successfully evaluated")
logger.info(f" ✅ Optimized for RTX 4060 with max_length=384")
logger.info(f" ✅ English visualizations generated")
logger.info(f" ✅ All results saved to: {output_dir}")
except Exception as e:
logger.error(f"❌ Error during evaluation: {str(e)}")
import traceback
traceback.print_exc()
raise
if __name__ == "__main__":
main()
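# Run sketch: adjust the four paths at the top of main() for the local environment,
# then run `python Project.py`; the JSON reports and the two PNG plots listed in
# files_generated are written to output_dir.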

BIN
6.model_train-test/test/class_performance_analysis.png

Binary file not shown (271 KiB).

BIN
6.model_train-test/test/confusion_matrix_test_results.png

Binary file not shown (208 KiB).

52208
6.model_train-test/test/detailed_predictions.json

File diff suppressed because it is too large

70
6.model_train-test/test/test_results_detailed.json

@@ -0,0 +1,70 @@
{
"overall_accuracy": 0.9827222924568058,
"class_metrics": {
"class_0_same_paragraph": {
"precision": 0.995013061030634,
"recall": 0.9856504351917196,
"f1_score": 0.9903096194753014,
"support": 4251
},
"class_1_different_paragraph": {
"precision": 0.8859813084112149,
"recall": 0.9575757575757575,
"f1_score": 0.920388349514563,
"support": 495
}
},
"confusion_matrix": [
[
4190,
61
],
[
21,
474
]
],
"classification_report": {
"Same Paragraph (0)": {
"precision": 0.995013061030634,
"recall": 0.9856504351917196,
"f1-score": 0.9903096194753014,
"support": 4251.0
},
"Different Paragraph (1)": {
"precision": 0.8859813084112149,
"recall": 0.9575757575757575,
"f1-score": 0.920388349514563,
"support": 495.0
},
"accuracy": 0.9827222924568058,
"macro avg": {
"precision": 0.9404971847209245,
"recall": 0.9716130963837386,
"f1-score": 0.9553489844949322,
"support": 4746.0
},
"weighted avg": {
"precision": 0.9836412284249424,
"recall": 0.9827222924568058,
"f1-score": 0.9830169459332522,
"support": 4746.0
}
},
"test_samples_count": 4746,
"boundary_score_stats": {
"mean": 0.14779670536518097,
"std": 0.29127171635627747,
"min": 0.042458675801754,
"max": 0.9999679327011108
},
"evaluation_info": {
"evaluation_time": "0:01:19.859698",
"device_used": "cuda",
"model_type": "DualPathBoundaryClassifier",
"max_length": 384,
"batch_size": 32,
"test_file": "D:\\workstation\\AI标注\\数据清洗+json\\test_dataset.json",
"trained_model_path": "D:\\workstation\\chinese-roberta-wwm-ext\\model-train-eval-NN\\model_train"
}
}

41
6.model_train-test/test/test_summary.json

@@ -0,0 +1,41 @@
{
"model_info": {
"model_type": "Dual Path Boundary Classifier",
"base_model": "Chinese-RoBERTa-WWM-Ext",
"max_length": 384,
"trained_model_path": "D:\\workstation\\chinese-roberta-wwm-ext\\model-train-eval-NN\\model_train"
},
"test_results": {
"overall_accuracy": 0.9827222924568058,
"total_samples": 4746,
"correct_predictions": 4664,
"incorrect_predictions": 82
},
"class_performance": {
"class_0_same_paragraph": {
"precision": 0.995013061030634,
"recall": 0.9856504351917196,
"f1_score": 0.9903096194753014,
"support": 4251
},
"class_1_different_paragraph": {
"precision": 0.8859813084112149,
"recall": 0.9575757575757575,
"f1_score": 0.920388349514563,
"support": 495
}
},
"boundary_detection": {
"mean": 0.14779670536518097,
"std": 0.29127171635627747,
"min": 0.042458675801754,
"max": 0.9999679327011108
},
"files_generated": [
"test_results_detailed.json",
"detailed_predictions.json",
"confusion_matrix_test_results.png",
"class_performance_analysis.png",
"test_summary.json"
]
}

1059
FinalData/batch_deduplication_results_8-1103.csv

File diff suppressed because one or more lines are too long

51
FinalData/batch_deduplication_results_new50.csv

File diff suppressed because one or more lines are too long

51
FinalData/merged-new50.csv

File diff suppressed because one or more lines are too long

324480
FinalData/segmentation_results_from_7-1103_retried.json

File diff suppressed because it is too large

332390
FinalData/segmentation_results_from_7-1153_retried.json

File diff suppressed because it is too large