import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import os
import re
import time
import concurrent.futures
from typing import Dict, List, Optional, Union
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from transformers import BertTokenizer, BertModel
import uvicorn
class ScaledDotProductAttention(nn.Module):
def __init__(self, d_model, dropout=0.1):
super(ScaledDotProductAttention, self).__init__()
self.d_model = d_model
self.dropout = nn.Dropout(dropout)
def forward(self, query, key, value, mask=None):
batch_size, seq_len, d_model = query.size()
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_model)
if mask is not None:
mask_value = torch.finfo(scores.dtype).min
scores = scores.masked_fill(mask == 0, mask_value)
attention_weights = F.softmax(scores, dim=-1)
attention_weights = self.dropout(attention_weights)
output = torch.matmul(attention_weights, value)
return output, attention_weights
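
# A minimal sketch (not used by the service) of what the module above computes:
# single-head attention softmax(Q K^T / sqrt(d_model)) V over the full hidden size.
# The helper below is purely illustrative and assumes small random inputs.
def _attention_shape_check():
    attn = ScaledDotProductAttention(d_model=8, dropout=0.0)
    q = torch.randn(2, 5, 8)
    out, weights = attn(q, q, q)
    # out: (batch, seq, d_model) = (2, 5, 8); weights: (batch, seq, seq) = (2, 5, 5)
    return out.shape, weights.shape
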
class FocalLoss(nn.Module):
def __init__(self, alpha=None, gamma=3.0, reduction='mean'):
super(FocalLoss, self).__init__()
self.alpha = alpha
self.gamma = gamma
self.reduction = reduction
def forward(self, inputs, targets):
ce_loss = F.cross_entropy(inputs, targets, reduction='none')
pt = torch.exp(-ce_loss)
if self.alpha is not None:
if self.alpha.type() != inputs.data.type():
self.alpha = self.alpha.type_as(inputs.data)
at = self.alpha.gather(0, targets.data.view(-1))
ce_loss = ce_loss * at
focal_weight = (1 - pt) ** self.gamma
focal_loss = focal_weight * ce_loss
if self.reduction == 'mean':
return focal_loss.mean()
elif self.reduction == 'sum':
return focal_loss.sum()
else:
return focal_loss
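
# A minimal sanity sketch (not called by the service) for the focal loss above:
# FL = alpha_t * (1 - p_t)^gamma * CE, so with gamma = 0 and no alpha it should
# match plain cross-entropy. The logits and targets below are made up.
def _focal_loss_sanity_check():
    logits = torch.randn(4, 2)
    targets = torch.tensor([0, 1, 1, 0])
    ce = F.cross_entropy(logits, targets)
    fl_gamma0 = FocalLoss(gamma=0.0)(logits, targets)
    return torch.allclose(ce, fl_gamma0)  # expected: True
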
class DualPathBoundaryClassifier(nn.Module):
def __init__(self, model_path, num_labels=2, dropout=0.1,
focal_alpha=None, focal_gamma=3.0, boundary_force_weight=2.0):
super(DualPathBoundaryClassifier, self).__init__()
self.roberta = BertModel.from_pretrained(model_path)
self.config = self.roberta.config
self.config.num_labels = num_labels
self.scaled_attention = ScaledDotProductAttention(
d_model=self.config.hidden_size,
dropout=dropout
)
self.dropout = nn.Dropout(dropout)
        # Dual-path classifiers
self.regular_classifier = nn.Linear(self.config.hidden_size, num_labels)
self.boundary_classifier = nn.Linear(self.config.hidden_size, num_labels)
self.boundary_detector = nn.Linear(self.config.hidden_size, 1)
        # Boundary forcing weight
self.boundary_force_weight = nn.Parameter(torch.tensor(boundary_force_weight))
self.focal_loss = FocalLoss(alpha=focal_alpha, gamma=focal_gamma)
self._init_weights()
self.focal_alpha = focal_alpha
self.focal_gamma = focal_gamma
def _init_weights(self):
nn.init.normal_(self.regular_classifier.weight, std=0.02)
nn.init.zeros_(self.regular_classifier.bias)
nn.init.normal_(self.boundary_classifier.weight, std=0.02)
nn.init.zeros_(self.boundary_classifier.bias)
nn.init.normal_(self.boundary_detector.weight, std=0.02)
nn.init.zeros_(self.boundary_detector.bias)
def forward(self, input_ids, attention_mask=None, token_type_ids=None, labels=None):
roberta_outputs = self.roberta(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
return_dict=True
)
sequence_output = roberta_outputs.last_hidden_state
        # Scaled dot-product attention enhancement
enhanced_output, attention_weights = self.scaled_attention(
query=sequence_output,
key=sequence_output,
value=sequence_output,
mask=attention_mask.unsqueeze(1) if attention_mask is not None else None
)
cls_output = enhanced_output[:, 0, :]
cls_output = self.dropout(cls_output)
        # Dual-path classification
regular_logits = self.regular_classifier(cls_output)
boundary_logits = self.boundary_classifier(cls_output)
        # Boundary detection
boundary_logits_raw = self.boundary_detector(cls_output).squeeze(-1)
boundary_score = torch.sigmoid(boundary_logits_raw)
        # Dynamic fusion
boundary_bias = torch.zeros_like(regular_logits)
boundary_bias[:, 1] = boundary_score * self.boundary_force_weight
final_logits = regular_logits + boundary_bias
loss = None
if labels is not None:
regular_loss = self.focal_loss(regular_logits, labels)
boundary_loss = self.focal_loss(boundary_logits, labels)
final_loss = self.focal_loss(final_logits, labels)
boundary_labels = self._generate_boundary_labels(labels)
detection_loss = F.binary_cross_entropy_with_logits(boundary_logits_raw, boundary_labels)
total_loss = (0.4 * final_loss +
0.3 * regular_loss +
0.2 * boundary_loss +
0.1 * detection_loss)
loss = total_loss
return {
'loss': loss,
'logits': final_logits,
'regular_logits': regular_logits,
'boundary_logits': boundary_logits,
'boundary_score': boundary_score,
'hidden_states': enhanced_output,
'attention_weights': attention_weights
}
def _generate_boundary_labels(self, labels):
boundary_labels = labels.float()
noise = torch.rand_like(boundary_labels) * 0.1
boundary_labels = torch.clamp(boundary_labels + noise, 0.0, 1.0)
return boundary_labels
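
# Sketch of the dynamic fusion performed in forward() above: only the
# "different paragraph" logit (index 1) is biased by the learned boundary score.
# Numbers are made up for illustration; the helper is not used by the service.
def _fusion_example():
    regular_logits = torch.tensor([[1.2, 0.8]])  # raw regular-path output
    boundary_score = torch.tensor([0.9])         # sigmoid of the boundary detector
    boundary_force_weight = 2.0
    bias = torch.zeros_like(regular_logits)
    bias[:, 1] = boundary_score * boundary_force_weight
    return regular_logits + bias  # tensor([[1.2, 2.6]]) -> predicts "different paragraph"
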
def check_gpu_availability():
if torch.cuda.is_available():
gpu_name = torch.cuda.get_device_name(0)
gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024 ** 3
print(f"🚀 GPU: {gpu_name} ({gpu_memory:.1f} GB)")
return torch.device('cuda')
else:
print("🔄 使用CPU")
return torch.device('cpu')
def safe_convert_focal_alpha(focal_alpha, device):
if focal_alpha is None:
return None
try:
if isinstance(focal_alpha, torch.Tensor):
return focal_alpha.to(device)
elif isinstance(focal_alpha, (list, tuple)):
return torch.tensor(focal_alpha, dtype=torch.float32).to(device)
elif isinstance(focal_alpha, np.ndarray):
return torch.from_numpy(focal_alpha).float().to(device)
elif isinstance(focal_alpha, (int, float)):
return torch.tensor([focal_alpha], dtype=torch.float32).to(device)
else:
return torch.tensor(focal_alpha, dtype=torch.float32).to(device)
except Exception as e:
print(f" 转换focal_alpha失败: {e}")
return None
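
# Typical conversions handled above (illustrative): a per-class list such as
# [0.25, 0.75] becomes a float32 tensor on the target device, a bare float becomes
# a one-element tensor, an existing tensor/ndarray is moved/cast, and None stays None.
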
def load_trained_dual_path_model(model_path, device, original_roberta_path=None):
try:
if original_roberta_path and os.path.exists(original_roberta_path):
tokenizer_path = original_roberta_path
else:
tokenizer_path = model_path
tokenizer = BertTokenizer.from_pretrained(tokenizer_path)
config_path = os.path.join(model_path, 'config.json')
if os.path.exists(config_path):
with open(config_path, 'r', encoding='utf-8') as f:
config = json.load(f)
focal_gamma = config.get('focal_gamma', 3.0)
focal_alpha_raw = config.get('focal_alpha', None)
boundary_force_weight = config.get('boundary_force_weight', 2.0)
else:
focal_gamma = 3.0
focal_alpha_raw = None
boundary_force_weight = 2.0
focal_alpha = safe_convert_focal_alpha(focal_alpha_raw, device)
if original_roberta_path and os.path.exists(original_roberta_path):
model_init_path = original_roberta_path
else:
model_init_path = tokenizer_path
model = DualPathBoundaryClassifier(
model_path=model_init_path,
num_labels=2,
dropout=0.1,
focal_alpha=focal_alpha,
focal_gamma=focal_gamma,
boundary_force_weight=boundary_force_weight
)
model_vocab_size = model.roberta.embeddings.word_embeddings.weight.shape[0]
if model_vocab_size != tokenizer.vocab_size:
raise ValueError(f"词汇表大小不匹配: 模型={model_vocab_size}, Tokenizer={tokenizer.vocab_size}")
model_weights_path = os.path.join(model_path, 'pytorch_model.bin')
if os.path.exists(model_weights_path):
try:
state_dict = torch.load(model_weights_path, map_location=device)
except Exception:
state_dict = torch.load(model_weights_path, map_location='cpu')
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
if missing_keys or unexpected_keys:
print(f" 加载权重时有差异: 缺少{len(missing_keys)}个键, 多余{len(unexpected_keys)}个键")
else:
raise FileNotFoundError(f"未找到模型权重文件: {model_weights_path}")
model.to(device)
model.eval()
print(" 双路径边界分类器模型加载完成")
return model, tokenizer
except Exception as e:
print(f"❌ 双路径模型加载失败: {str(e)}")
raise
def split_text_into_sentences(text: str) -> List[str]:
text = text.strip()
sentence_endings = r'([。!?;])'
parts = re.split(sentence_endings, text)
sentences = []
for i in range(0, len(parts), 2):
if i < len(parts):
sentence = parts[i].strip()
if sentence:
if i + 1 < len(parts):
sentence += parts[i + 1]
sentences.append(sentence)
return sentences
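
# Illustration of the splitter above (a sketch): sentences are cut on the
# full-width delimiters 。!?; and keep their ending punctuation, e.g.
#   split_text_into_sentences("今天天气很好。我们去公园吧!好的?")
#   -> ["今天天气很好。", "我们去公园吧!", "好的?"]
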
def predict_sentence_pairs_dual_path(sentences: List[str], model, tokenizer, device, max_length=384) -> Dict[str, str]:
if len(sentences) < 2:
return {"paragraph_1": sentences[0] if sentences else ""}
results = {}
current_paragraph_sentences = [sentences[0]]
with torch.no_grad():
for i in range(len(sentences) - 1):
            # Modified: keep punctuation, consistent with file 1
            sentence1_clean = sentences[i]  # punctuation no longer stripped
            sentence2_clean = sentences[i + 1]  # punctuation no longer stripped
encoding = tokenizer(
sentence1_clean,
sentence2_clean,
truncation=True,
padding='max_length',
max_length=max_length,
return_tensors='pt'
)
input_ids = encoding['input_ids'].to(device)
attention_mask = encoding['attention_mask'].to(device)
outputs = model(
input_ids=input_ids,
attention_mask=attention_mask
)
logits = outputs['logits']
prediction = torch.argmax(logits, dim=-1).item()
            if prediction == 0:  # same paragraph
current_paragraph_sentences.append(sentences[i + 1])
            else:  # different paragraph
paragraph_key = f"paragraph_{len(results) + 1}"
results[paragraph_key] = "".join(current_paragraph_sentences)
current_paragraph_sentences = [sentences[i + 1]]
if current_paragraph_sentences:
paragraph_key = f"paragraph_{len(results) + 1}"
results[paragraph_key] = "".join(current_paragraph_sentences)
return results
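
# How the pairwise loop above assembles paragraphs (illustrative): with sentences
# s1..s4 and per-pair predictions [0, 1, 0] (0 = same paragraph, 1 = boundary),
# the result is {"paragraph_1": s1 + s2, "paragraph_2": s3 + s4}.
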
def process_single_broadcast_dual_path(text: str, broadcast_id: Optional[str] = None) -> dict:
try:
if not text or not text.strip():
return {
"broadcast_id": broadcast_id,
"segments": {},
"status": "failed",
"error": "文本为空"
}
sentences = split_text_into_sentences(text)
if len(sentences) == 0:
return {
"broadcast_id": broadcast_id,
"segments": {"paragraph_1": text},
"status": "success"
}
segments = predict_sentence_pairs_dual_path(sentences, model, tokenizer, device)
return {
"broadcast_id": broadcast_id,
"segments": segments,
"status": "success"
}
except Exception as e:
return {
"broadcast_id": broadcast_id,
"segments": {},
"status": "failed",
"error": str(e)
}
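
# Shape of the per-broadcast result produced above (illustrative):
#   {"broadcast_id": ..., "segments": {"paragraph_1": "...", ...}, "status": "success"}
# or, on failure, {"broadcast_id": ..., "segments": {}, "status": "failed", "error": "..."}
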
# ========== FastAPI application ==========
app = FastAPI(title="双路径边界分类器文本分段服务", version="3.1.0")
# Global variables holding the model
model = None
tokenizer = None
device = None
# ========== Request and response models ==========
class TextInput(BaseModel):
广播内容: str
class BroadcastItem(BaseModel):
广播内容: str
广播ID: Optional[str] = None
BatchInput = List[BroadcastItem]
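
# Example request bodies for the endpoints below (illustrative only; the Chinese
# field names are exactly the Pydantic fields defined above):
#   POST /segment_simple        {"广播内容": "第一句。第二句。"}
#   POST /segment_batch_simple  [{"广播内容": "第一句。第二句。", "广播ID": "b1"}]
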
# ========== Lifecycle events ==========
@app.on_event("startup")
async def load_model():
global model, tokenizer, device
model_path = r"D:\workstation\chinese-roberta-wwm-ext\model-train-eval-NN\model_train-NN"
original_roberta_path = r"D:\workstation\chinese-roberta-wwm-ext\model"
print("🚀 正在启动双路径边界分类器文本分段服务...")
try:
device = check_gpu_availability()
model, tokenizer = load_trained_dual_path_model(model_path, device, original_roberta_path)
        # print(f"✅ Dual-path boundary classifier loaded successfully! Device: {device}")
        # print("📝 Change note: punctuation handling is preserved, consistent with the test script")
except Exception as e:
print(f"❌ 双路径模型加载失败: {e}")
raise
@app.get("/")
async def root():
return {
"service": "双路径边界分类器文本分段服务",
"status": "运行中" if model is not None else "模型未加载",
"version": "3.1.0",
"model_type": "DualPathBoundaryClassifier",
"updates": [
"v3.1.0: 修复标点符号处理不一致问题",
"保留句末标点符号,与测试脚本保持一致",
"提高分段准确性"
],
"features": [
"双路径架构: 常规分类器 + 边界分类器",
"边界检测器: 纯神经网络学习边界模式",
"动态权重融合: 自适应边界识别",
"序列长度: 384 tokens",
"标点符号保留: 保持训练-推理一致性",
"单条处理",
"批量处理(简化)",
"详细分析"
]
}
@app.get("/health")
async def health_check():
"""健康检查接口"""
return {
"status": "healthy" if model is not None else "unhealthy",
"model_loaded": model is not None,
"model_type": "DualPathBoundaryClassifier",
"device": str(device) if device else None,
"max_sequence_length": 384,
"boundary_detection": "pure_neural_network",
"punctuation_handling": "preserved",
"version": "3.1.0"
}
@app.post("/segment_simple")
async def segment_text_simple(input_data: TextInput):
if model is None or tokenizer is None:
return {"error": "双路径模型未加载"}
try:
result = process_single_broadcast_dual_path(input_data.广播内容)
if result["status"] == "success":
return result["segments"]
else:
return {"error": result["error"]}
except Exception as e:
return {"error": f"双路径处理失败: {str(e)}"}
@app.post("/segment_batch_simple")
async def segment_batch_simple(broadcasts: BatchInput):
if model is None or tokenizer is None:
return {"error": "双路径模型未加载"}
try:
if not broadcasts:
return {"error": "广播列表不能为空"}
start_time = time.time()
results = []
        # Process broadcasts in parallel
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
future_to_broadcast = {
executor.submit(process_single_broadcast_dual_path,
broadcast.广播内容,
broadcast.广播ID): broadcast
for broadcast in broadcasts
}
for future in concurrent.futures.as_completed(future_to_broadcast):
try:
result = future.result()
results.append(result)
except Exception as e:
broadcast = future_to_broadcast[future]
results.append({
"broadcast_id": broadcast.广播ID,
"status": "failed",
"error": f"双路径处理异常: {str(e)}"
})
        # Simplified output format
simplified_results = {}
success_count = 0
for i, result in enumerate(results):
key = result.get("broadcast_id") or f"broadcast_{i + 1}"
if result["status"] == "success":
simplified_results[key] = result["segments"]
success_count += 1
else:
simplified_results[key] = {"error": result.get("error", "双路径处理失败")}
total_time = time.time() - start_time
return {
"model": "DualPathBoundaryClassifier",
"version": "3.1.0",
"total": len(results),
"success": success_count,
"failed": len(results) - success_count,
"processing_time": round(total_time, 3),
"results": simplified_results
}
except Exception as e:
return {"error": f"双路径批量处理失败: {str(e)}"}
@app.post("/segment_with_details")
async def segment_with_details(input_data: TextInput):
if model is None or tokenizer is None:
return {"error": "双路径模型未加载"}
try:
text = input_data.广播内容
sentences = split_text_into_sentences(text)
if len(sentences) < 2:
return {
"segments": {"paragraph_1": text},
"sentence_details": [],
"total_sentences": len(sentences),
"version": "3.1.0",
"punctuation_handling": "preserved"
}
results = {}
current_paragraph_sentences = [sentences[0]]
sentence_details = []
with torch.no_grad():
for i in range(len(sentences) - 1):
                sentence1_clean = sentences[i]  # punctuation no longer stripped
                sentence2_clean = sentences[i + 1]  # punctuation no longer stripped
encoding = tokenizer(
sentence1_clean,
sentence2_clean,
truncation=True,
padding='max_length',
max_length=384,
return_tensors='pt'
)
input_ids = encoding['input_ids'].to(device)
attention_mask = encoding['attention_mask'].to(device)
outputs = model(
input_ids=input_ids,
attention_mask=attention_mask
)
logits = outputs['logits']
regular_logits = outputs['regular_logits']
boundary_logits = outputs['boundary_logits']
boundary_score = outputs['boundary_score'].item()
prediction = torch.argmax(logits, dim=-1).item()
probability = torch.softmax(logits, dim=-1)
regular_prob = torch.softmax(regular_logits, dim=-1)
boundary_prob = torch.softmax(boundary_logits, dim=-1)
detail = {
"sentence_pair_index": i + 1,
"sentence1": sentences[i],
"sentence2": sentences[i + 1],
"sentence1_input": sentence1_clean, # 显示实际输入模型的文本
"sentence2_input": sentence2_clean, # 显示实际输入模型的文本
"prediction": prediction,
"prediction_label": "same_paragraph" if prediction == 0 else "different_paragraph",
"final_probabilities": {
"same_paragraph": float(probability[0][0]),
"different_paragraph": float(probability[0][1])
},
"regular_path_probabilities": {
"same_paragraph": float(regular_prob[0][0]),
"different_paragraph": float(regular_prob[0][1])
},
"boundary_path_probabilities": {
"same_paragraph": float(boundary_prob[0][0]),
"different_paragraph": float(boundary_prob[0][1])
},
"boundary_score": boundary_score,
"boundary_confidence": "high" if boundary_score > 0.7 else "medium" if boundary_score > 0.3 else "low"
}
sentence_details.append(detail)
                if prediction == 0:  # same paragraph
current_paragraph_sentences.append(sentences[i + 1])
                else:  # different paragraph
paragraph_key = f"paragraph_{len(results) + 1}"
results[paragraph_key] = "".join(current_paragraph_sentences)
current_paragraph_sentences = [sentences[i + 1]]
if current_paragraph_sentences:
paragraph_key = f"paragraph_{len(results) + 1}"
results[paragraph_key] = "".join(current_paragraph_sentences)
boundary_scores = [d["boundary_score"] for d in sentence_details]
return {
"segments": results,
"sentence_details": sentence_details,
"total_sentences": len(sentences),
"total_pairs_analyzed": len(sentence_details),
"boundary_statistics": {
"average_boundary_score": round(sum(boundary_scores) / len(boundary_scores),
4) if boundary_scores else 0,
"max_boundary_score": round(max(boundary_scores), 4) if boundary_scores else 0,
"min_boundary_score": round(min(boundary_scores), 4) if boundary_scores else 0
},
"version": "3.1.0",
"punctuation_handling": "preserved",
"processing_note": "句末标点符号已保留,与训练数据保持一致"
}
except Exception as e:
return {"error": f"详细分析失败: {str(e)}"}
@app.get("/stats")
async def get_processing_stats():
"""获取处理统计信息"""
return {
"service_status": "running" if model is not None else "down",
"model_loaded": model is not None,
"model_type": "DualPathBoundaryClassifier",
"device": str(device) if device else None,
"vocab_size": tokenizer.vocab_size if tokenizer else None,
"max_sequence_length": 384,
"version": "3.1.0",
"punctuation_handling": "preserved",
"consistency": "aligned_with_training_data",
"api_endpoints": [
"/segment_simple - 单条处理",
"/segment_batch_simple - 批量处理(简化)",
"/segment_with_details - 带详细信息分段"
]
}
# ========== Startup configuration ==========
if __name__ == "__main__":
uvicorn.run(app, host='0.0.0.0', port=8888)
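
# Example invocation once the service is running (a sketch; host/port as configured above):
#   curl -X POST http://localhost:8888/segment_simple \
#        -H "Content-Type: application/json" \
#        -d '{"广播内容": "第一句。第二句。"}'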