You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

656 lines
23 KiB

import asyncio
import concurrent.futures
import json
import math
import os
import re
import time
from typing import Dict, List, Optional, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from transformers import BertTokenizer, BertModel
# ========== 双路径边界分类器模型定义部分 ==========
class ScaledDotProductAttention(nn.Module):
    """Single-head scaled dot-product attention: softmax(Q K^T / sqrt(d)) V.

    Args:
        d_model: feature size given at construction. It is stored for reference;
            the scaling factor in ``forward`` uses the actual last dimension of ``query``.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, d_model, dropout=0.1):
        super(ScaledDotProductAttention, self).__init__()
        self.d_model = d_model
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        """Attend ``query`` over ``key``/``value``.

        Args:
            query, key, value: tensors shaped ``(..., seq_len, d)``. Any number of
                leading (batch / head) dimensions is supported.
            mask: optional tensor broadcastable to the score shape
                ``(..., seq_q, seq_k)``. Positions where ``mask == 0`` are excluded.

        Returns:
            ``(output, attention_weights)``, where ``output`` has ``value``'s feature
            size and ``attention_weights`` has shape ``(..., seq_q, seq_k)``.
        """
        d_k = query.size(-1)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            # Use the most negative finite value rather than -inf: a fully masked
            # row then softmaxes to a uniform distribution instead of NaN.
            mask_value = torch.finfo(scores.dtype).min
            scores = scores.masked_fill(mask == 0, mask_value)
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        output = torch.matmul(attention_weights, value)
        return output, attention_weights
class FocalLoss(nn.Module):
    """Focal loss on top of softmax cross-entropy.

    ``loss_i = alpha_t * (1 - p_t) ** gamma * CE_i``, where ``p_t = exp(-CE_i)`` is
    the predicted probability of the true class.

    Args:
        alpha: optional class weights. A per-class sequence/tensor is indexed by the
            target. A single value (0-d or 1-element) is applied uniformly to every
            sample. Lists, tuples and arrays are converted to a float tensor.
        gamma: focusing parameter; ``gamma=0`` reduces to (weighted) cross-entropy.
        reduction: ``'mean'``, ``'sum'``, or anything else for the unreduced
            per-sample losses.
    """

    def __init__(self, alpha=None, gamma=3.0, reduction='mean'):
        super(FocalLoss, self).__init__()
        if alpha is not None and not isinstance(alpha, torch.Tensor):
            alpha = torch.as_tensor(alpha, dtype=torch.float32)
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Compute the focal loss for logits ``inputs`` and class-index ``targets``."""
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        # p_t comes from the *unweighted* CE so alpha does not leak into the focal term.
        pt = torch.exp(-ce_loss)
        if self.alpha is not None:
            # alpha is not a registered buffer, so model.to(device) does not move it.
            # Match the logits' device/dtype here, without mutating self.alpha.
            alpha = self.alpha.to(device=inputs.device, dtype=inputs.dtype)
            if alpha.numel() == 1:
                # A single weight cannot be indexed per class; apply it uniformly.
                at = alpha.reshape(())
            else:
                at = alpha.gather(0, targets.view(-1))
            ce_loss = ce_loss * at
        focal_weight = (1 - pt) ** self.gamma
        focal_loss = focal_weight * ce_loss
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss
class DualPathBoundaryClassifier(nn.Module):
    """BERT encoder with two classification heads and a boundary detector.

    After the backbone, an extra ScaledDotProductAttention layer is applied. Its
    first ([CLS]) position feeds three heads:

    * ``regular_classifier`` and ``boundary_classifier`` (``num_labels`` logits each);
    * ``boundary_detector`` (one logit). Its sigmoid score, scaled by the learnable
      ``boundary_force_weight``, is added to the class-1 logit of the regular path.

    The fused logits are returned as ``'logits'``.

    Note: the backbone attribute is named ``roberta`` although it is loaded with
    ``BertModel``. The name is kept because saved state_dict keys depend on it.
    """

    def __init__(self, model_path, num_labels=2, dropout=0.1,
                 focal_alpha=None, focal_gamma=3.0, boundary_force_weight=2.0):
        super(DualPathBoundaryClassifier, self).__init__()
        self.roberta = BertModel.from_pretrained(model_path)
        self.config = self.roberta.config
        self.config.num_labels = num_labels
        self.scaled_attention = ScaledDotProductAttention(
            d_model=self.config.hidden_size,
            dropout=dropout
        )
        self.dropout = nn.Dropout(dropout)
        # Two parallel classifiers over the same [CLS] representation.
        self.regular_classifier = nn.Linear(self.config.hidden_size, num_labels)
        self.boundary_classifier = nn.Linear(self.config.hidden_size, num_labels)
        self.boundary_detector = nn.Linear(self.config.hidden_size, 1)
        # Learnable scale for the boundary bias. Coerce to float: an integer config
        # value (e.g. JSON `2`) would otherwise give an int64 tensor, and
        # nn.Parameter rejects it because integer tensors cannot require gradients.
        self.boundary_force_weight = nn.Parameter(torch.tensor(float(boundary_force_weight)))
        self.focal_loss = FocalLoss(alpha=focal_alpha, gamma=focal_gamma)
        self._init_weights()
        self.focal_alpha = focal_alpha
        self.focal_gamma = focal_gamma

    def _init_weights(self):
        """Initialise the three head layers: N(0, 0.02) weights, zero biases."""
        for layer in (self.regular_classifier, self.boundary_classifier, self.boundary_detector):
            nn.init.normal_(layer.weight, std=0.02)
            nn.init.zeros_(layer.bias)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, labels=None):
        """Run the backbone and all heads; compute the combined loss when ``labels`` is given.

        Returns:
            dict with ``loss`` (None without labels), ``logits`` (fused),
            ``regular_logits``, ``boundary_logits``, ``boundary_score`` (sigmoid of the
            detector), ``hidden_states`` (attention-enhanced sequence) and
            ``attention_weights``.
        """
        roberta_outputs = self.roberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=True
        )
        sequence_output = roberta_outputs.last_hidden_state
        # Extra self-attention over the encoder output. The (B, 1, S) mask hides
        # padded keys from every query position.
        enhanced_output, attention_weights = self.scaled_attention(
            query=sequence_output,
            key=sequence_output,
            value=sequence_output,
            mask=attention_mask.unsqueeze(1) if attention_mask is not None else None
        )
        cls_output = enhanced_output[:, 0, :]
        cls_output = self.dropout(cls_output)
        # Two classification paths.
        regular_logits = self.regular_classifier(cls_output)
        boundary_logits = self.boundary_classifier(cls_output)
        # Boundary detection.
        boundary_logits_raw = self.boundary_detector(cls_output).squeeze(-1)
        boundary_score = torch.sigmoid(boundary_logits_raw)
        # Fusion: push the class-1 logit up in proportion to the boundary score.
        # Assumes num_labels >= 2.
        boundary_bias = torch.zeros_like(regular_logits)
        boundary_bias[:, 1] = boundary_score * self.boundary_force_weight
        final_logits = regular_logits + boundary_bias
        loss = None
        if labels is not None:
            regular_loss = self.focal_loss(regular_logits, labels)
            boundary_loss = self.focal_loss(boundary_logits, labels)
            final_loss = self.focal_loss(final_logits, labels)
            boundary_labels = self._generate_boundary_labels(labels)
            detection_loss = F.binary_cross_entropy_with_logits(boundary_logits_raw, boundary_labels)
            loss = (0.4 * final_loss +
                    0.3 * regular_loss +
                    0.2 * boundary_loss +
                    0.1 * detection_loss)
        return {
            'loss': loss,
            'logits': final_logits,
            'regular_logits': regular_logits,
            'boundary_logits': boundary_logits,
            'boundary_score': boundary_score,
            'hidden_states': enhanced_output,
            'attention_weights': attention_weights
        }

    def _generate_boundary_labels(self, labels):
        """Turn class labels into soft targets for the boundary detector.

        Label 1 stays 1.0 (noise is clamped away). Label 0 becomes a random value in
        [0, 0.1). The noise is drawn fresh on every call, so this is non-deterministic.
        """
        boundary_labels = labels.float()
        noise = torch.rand_like(boundary_labels) * 0.1
        boundary_labels = torch.clamp(boundary_labels + noise, 0.0, 1.0)
        return boundary_labels
# ========== 模型加载部分 ==========
def check_gpu_availability():
    """Choose the compute device and print a one-line summary of it.

    Returns:
        ``torch.device('cuda')`` when CUDA is available (after reporting GPU 0's
        name and total memory in GiB), otherwise ``torch.device('cpu')``.
    """
    if not torch.cuda.is_available():
        print("🔄 使用CPU")
        return torch.device('cpu')
    props = torch.cuda.get_device_properties(0)
    total_gib = props.total_memory / (1 << 30)
    print(f"🚀 GPU: {props.name} ({total_gib:.1f} GB)")
    return torch.device('cuda')
def safe_convert_focal_alpha(focal_alpha, device):
    """Convert a configured focal-loss ``alpha`` into a tensor on ``device``.

    Tensors keep their dtype. NumPy arrays, sequences and Python scalars become
    float32; a bare scalar is wrapped into a 1-element tensor. ``None`` stays
    ``None``. Any conversion error is printed and ``None`` is returned.
    """
    if focal_alpha is None:
        return None
    try:
        if isinstance(focal_alpha, torch.Tensor):
            converted = focal_alpha
        elif isinstance(focal_alpha, np.ndarray):
            converted = torch.from_numpy(focal_alpha).float()
        else:
            payload = [focal_alpha] if isinstance(focal_alpha, (int, float)) else focal_alpha
            converted = torch.tensor(payload, dtype=torch.float32)
        return converted.to(device)
    except Exception as err:
        print(f" 转换focal_alpha失败: {err}")
        return None
def _read_dual_path_config(model_path):
    """Return ``(focal_gamma, focal_alpha_raw, boundary_force_weight)`` from ``model_path/config.json``.

    A missing file, missing keys, or JSON ``null`` fall back to 3.0 / None / 2.0.
    Numeric values are coerced to float, so an integer JSON value (e.g. ``2``)
    never yields an integer tensor downstream.
    """
    config_path = os.path.join(model_path, 'config.json')
    if not os.path.exists(config_path):
        return 3.0, None, 2.0
    with open(config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)
    focal_gamma = config.get('focal_gamma')
    boundary_force_weight = config.get('boundary_force_weight')
    return (
        3.0 if focal_gamma is None else float(focal_gamma),
        config.get('focal_alpha'),
        2.0 if boundary_force_weight is None else float(boundary_force_weight),
    )


def load_trained_dual_path_model(model_path, device, original_roberta_path=None):
    """Build a DualPathBoundaryClassifier and load fine-tuned weights from ``model_path``.

    Args:
        model_path: directory containing ``pytorch_model.bin`` and, optionally, a
            ``config.json`` with ``focal_gamma`` / ``focal_alpha`` /
            ``boundary_force_weight``.
        device: device the model is moved to.
        original_roberta_path: optional directory of the pretrained base checkpoint.
            When it exists it is used for both the tokenizer and the backbone
            initialisation; otherwise ``model_path`` is used for both.

    Returns:
        ``(model, tokenizer)``, with the model in eval mode on ``device``.

    Raises:
        FileNotFoundError: ``pytorch_model.bin`` is missing (checked before the
            backbone is built).
        ValueError: backbone and tokenizer vocabulary sizes differ.
        Any other construction or loading error is printed and re-raised.
    """
    try:
        if original_roberta_path and os.path.exists(original_roberta_path):
            base_path = original_roberta_path
        else:
            base_path = model_path

        # Fail fast, before the slow backbone initialisation, if the weights are absent.
        model_weights_path = os.path.join(model_path, 'pytorch_model.bin')
        if not os.path.exists(model_weights_path):
            raise FileNotFoundError(f"未找到模型权重文件: {model_weights_path}")

        tokenizer = BertTokenizer.from_pretrained(base_path)
        focal_gamma, focal_alpha_raw, boundary_force_weight = _read_dual_path_config(model_path)
        focal_alpha = safe_convert_focal_alpha(focal_alpha_raw, device)

        model = DualPathBoundaryClassifier(
            model_path=base_path,
            num_labels=2,
            dropout=0.1,
            focal_alpha=focal_alpha,
            focal_gamma=focal_gamma,
            boundary_force_weight=boundary_force_weight
        )
        model_vocab_size = model.roberta.embeddings.word_embeddings.weight.shape[0]
        if model_vocab_size != tokenizer.vocab_size:
            raise ValueError(f"词汇表大小不匹配: 模型={model_vocab_size}, Tokenizer={tokenizer.vocab_size}")

        # NOTE: torch.load unpickles the file; only point this at trusted checkpoints.
        try:
            state_dict = torch.load(model_weights_path, map_location=device)
        except Exception:
            # Retry with a CPU mapping if mapping straight to `device` fails.
            state_dict = torch.load(model_weights_path, map_location='cpu')
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
        if missing_keys or unexpected_keys:
            print(f" 加载权重时有差异: 缺少{len(missing_keys)}个键, 多余{len(unexpected_keys)}个键")
            # strict=False tolerates mismatches. Name a few so that a randomly
            # initialised head (missing classifier keys) is obvious in the logs.
            if missing_keys:
                print(f"   missing (first 10): {list(missing_keys)[:10]}")
            if unexpected_keys:
                print(f"   unexpected (first 10): {list(unexpected_keys)[:10]}")

        model.to(device)
        model.eval()
        print("✅ 双路径边界分类器模型加载完成")
        return model, tokenizer
    except Exception as e:
        print(f"❌ 双路径模型加载失败: {str(e)}")
        raise
# ========== 文本处理部分 ==========
def split_text_into_sentences(text: str) -> List[str]:
    """Split ``text`` into sentences ending in 。 ! ? or ;, keeping each terminator.

    Whitespace around each sentence is stripped. A trailing fragment without a
    terminator is kept as its own sentence. Consecutive terminators (e.g. "?!" or
    "。。") stay attached to the preceding sentence instead of being dropped.
    Terminators that precede any sentence text are discarded.

    >>> split_text_into_sentences("你好。真的吗?!好")
    ['你好。', '真的吗?!', '好']
    """
    # The capturing group makes re.split return [body, term, body, term, ..., tail].
    parts = re.split(r'([。!?;])', text.strip())
    sentences: List[str] = []
    for i in range(0, len(parts), 2):
        body = parts[i].strip()
        terminator = parts[i + 1] if i + 1 < len(parts) else ''
        if body:
            sentences.append(body + terminator)
        elif terminator and sentences:
            # An empty body between two terminators: keep the punctuation on the
            # previous sentence so no text is lost from the reassembled paragraphs.
            sentences[-1] += terminator
    return sentences
def predict_sentence_pairs_dual_path(sentences: List[str], model, tokenizer, device, max_length=384,
                                     batch_size: int = 32) -> Dict[str, str]:
    """Group consecutive sentences into paragraphs using the pairwise classifier.

    Each adjacent pair (s[i], s[i+1]) is classified after stripping one trailing
    terminator from each side. Predicted label 0 appends s[i+1] to the current
    paragraph; any other label starts a new paragraph.

    Pairs are encoded and run in batches of ``batch_size``. Each batch is padded
    to its longest pair (capped and truncated at ``max_length``) rather than always
    to ``max_length``. Padding is excluded through ``attention_mask``, so
    predictions match the fully padded form up to floating-point noise.

    Args:
        sentences: sentences in reading order.
        model: callable returning a dict with ``'logits'`` of shape (batch, num_labels).
        tokenizer: callable accepting batched text pairs (Hugging Face style).
        device: device the input tensors are moved to.
        max_length: token limit per pair.
        batch_size: number of pairs per forward pass (must be >= 1).

    Returns:
        ``{"paragraph_1": ..., "paragraph_2": ...}`` with each paragraph's sentences
        concatenated. Fewer than two sentences yield a single paragraph ("" if empty).
    """
    if len(sentences) < 2:
        return {"paragraph_1": sentences[0] if sentences else ""}
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    firsts = [re.sub(r'[。!?;]$', '', s) for s in sentences[:-1]]
    seconds = [re.sub(r'[。!?;]$', '', s) for s in sentences[1:]]

    predictions: List[int] = []
    with torch.no_grad():
        for start in range(0, len(firsts), batch_size):
            stop = start + batch_size
            encoding = tokenizer(
                firsts[start:stop],
                seconds[start:stop],
                truncation=True,
                padding=True,
                max_length=max_length,
                return_tensors='pt'
            )
            # NOTE(review): the tokenizer's token_type_ids are not forwarded, so the
            # backbone falls back to its default segment ids. Confirm this matches
            # how the model was trained.
            outputs = model(
                input_ids=encoding['input_ids'].to(device),
                attention_mask=encoding['attention_mask'].to(device)
            )
            predictions.extend(torch.argmax(outputs['logits'], dim=-1).tolist())

    results: Dict[str, str] = {}
    current_paragraph_sentences = [sentences[0]]
    for next_sentence, prediction in zip(sentences[1:], predictions):
        if prediction == 0:  # same paragraph
            current_paragraph_sentences.append(next_sentence)
        else:  # boundary: close the current paragraph
            results[f"paragraph_{len(results) + 1}"] = "".join(current_paragraph_sentences)
            current_paragraph_sentences = [next_sentence]
    results[f"paragraph_{len(results) + 1}"] = "".join(current_paragraph_sentences)
    return results
def process_single_broadcast_dual_path(text: str, broadcast_id: Optional[str] = None) -> dict:
    """Segment one broadcast text with the module-level model/tokenizer/device.

    Never raises. Any failure (including blank input) is reported as
    ``{"broadcast_id", "segments": {}, "status": "failed", "error"}``. Success gives
    ``{"broadcast_id", "segments", "status": "success"}``. Text that yields no
    sentences is returned whole as ``paragraph_1``.
    """
    def _failure(message):
        return {
            "broadcast_id": broadcast_id,
            "segments": {},
            "status": "failed",
            "error": message
        }

    try:
        if not (text and text.strip()):
            return _failure("文本为空")
        sentences = split_text_into_sentences(text)
        if sentences:
            segments = predict_sentence_pairs_dual_path(sentences, model, tokenizer, device)
        else:
            segments = {"paragraph_1": text}
        return {
            "broadcast_id": broadcast_id,
            "segments": segments,
            "status": "success"
        }
    except Exception as exc:
        return _failure(str(exc))
# ========== FastAPI应用部分 ==========
app = FastAPI(
    title="双路径边界分类器文本分段服务",
    version="3.0.0",
)
# Module-level model state: assigned by the startup hook, read by the request handlers.
model = tokenizer = device = None
# ========== 请求和响应模型 ==========
# Single-text request body. The JSON key is the field name 广播内容 ("broadcast content").
class TextInput(BaseModel):
    广播内容: str


# One entry of a batch request: the text plus an optional caller-supplied ID (广播ID).
class BroadcastItem(BaseModel):
    广播内容: str
    广播ID: Union[str, None] = None


# Batch request body: a bare JSON array of BroadcastItem objects.
BatchInput = List[BroadcastItem]
# ========== 生命周期事件 ==========
@app.on_event("startup")
async def load_model():
    """Load the device, model and tokenizer into the module-level globals at startup.

    Model locations default to the original hard-coded paths. They can be
    overridden with the ``DUAL_PATH_MODEL_PATH`` (fine-tuned checkpoint) and
    ``DUAL_PATH_BASE_MODEL_PATH`` (pretrained base) environment variables.
    Loading errors are printed and re-raised.
    """
    global model, tokenizer, device
    model_path = os.environ.get("DUAL_PATH_MODEL_PATH", "/work/model_robert/model_train-eval")
    original_roberta_path = os.environ.get("DUAL_PATH_BASE_MODEL_PATH", "/work/model_robert/model")
    print("🚀 正在启动双路径边界分类器文本分段服务...")
    try:
        device = check_gpu_availability()
        model, tokenizer = load_trained_dual_path_model(model_path, device, original_roberta_path)
        print(f"✅ 双路径边界分类器模型加载成功! 设备: {device}")
    except Exception as e:
        print(f"❌ 双路径模型加载失败: {e}")
        raise
# ========== API接口 ==========
@app.get("/")
async def root():
    """根路径 - 服务状态检查"""
    # Root endpoint: static service description plus a status derived from
    # whether the model global has been populated.
    loaded = model is not None
    feature_list = [
        "双路径架构: 常规分类器 + 边界分类器",
        "边界检测器: 纯神经网络学习边界模式",
        "动态权重融合: 自适应边界识别",
        "序列长度: 384 tokens",
        "单条处理",
        "批量处理(简化)",
        "详细分析",
    ]
    return {
        "service": "双路径边界分类器文本分段服务",
        "status": "运行中" if loaded else "模型未加载",
        "version": "3.0.0",
        "model_type": "DualPathBoundaryClassifier",
        "features": feature_list,
    }
@app.get("/health")
async def health_check():
    """健康检查接口"""
    # Health probe: reports whether the model global is loaded and on which device.
    is_loaded = model is not None
    device_name = str(device) if device else None
    return {
        "status": "healthy" if is_loaded else "unhealthy",
        "model_loaded": is_loaded,
        "model_type": "DualPathBoundaryClassifier",
        "device": device_name,
        "max_sequence_length": 384,
        "boundary_detection": "pure_neural_network",
    }
@app.post("/segment_simple")
async def segment_text_simple(input_data: TextInput):
    """单条文本分段接口"""
    # Segment one text and return only its paragraph dict, or {"error": ...}.
    # Model inference is synchronous and slow, so it runs in the event loop's
    # default executor. Calling it directly inside this coroutine would block the
    # loop and stall every other request (e.g. /health) until it finished.
    if model is None or tokenizer is None:
        return {"error": "双路径模型未加载"}
    try:
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(
            None, process_single_broadcast_dual_path, input_data.广播内容
        )
        if result["status"] == "success":
            return result["segments"]
        return {"error": result["error"]}
    except Exception as e:
        return {"error": f"双路径处理失败: {str(e)}"}
def _segment_batch_sync(broadcasts):
    """Segment every broadcast (up to 4 at a time) and build the simplified batch response.

    Blocking; intended to run off the event loop. Results are kept in request
    order, so a fallback key ``broadcast_{i}`` refers to the i-th input item.
    """
    start_time = time.time()
    with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
        futures = [
            executor.submit(process_single_broadcast_dual_path, item.广播内容, item.广播ID)
            for item in broadcasts
        ]
    # Leaving the with-block waits for all futures. Reading them in submission
    # order (not as_completed order) keeps results aligned with the request.
    results = []
    for item, future in zip(broadcasts, futures):
        try:
            results.append(future.result())
        except Exception as e:
            results.append({
                "broadcast_id": item.广播ID,
                "status": "failed",
                "error": f"双路径处理异常: {str(e)}"
            })

    # Simplified output: one entry per broadcast, keyed by its ID or its position.
    # NOTE(review): duplicate 广播ID values overwrite each other here, while
    # "total" still counts every item.
    simplified_results = {}
    success_count = 0
    for i, result in enumerate(results):
        key = result.get("broadcast_id") or f"broadcast_{i + 1}"
        if result["status"] == "success":
            simplified_results[key] = result["segments"]
            success_count += 1
        else:
            simplified_results[key] = {"error": result.get("error", "双路径处理失败")}
    total_time = time.time() - start_time
    return {
        "model": "DualPathBoundaryClassifier",
        "total": len(results),
        "success": success_count,
        "failed": len(results) - success_count,
        "processing_time": round(total_time, 3),
        "results": simplified_results
    }


@app.post("/segment_batch_simple")
async def segment_batch_simple(broadcasts: BatchInput):
    """批量文本分段接口 - 简化输出"""
    # Batch segmentation endpoint. The blocking work (thread pool plus model
    # inference) runs in the loop's default executor so the event loop stays
    # responsive while a batch is processed.
    if model is None or tokenizer is None:
        return {"error": "双路径模型未加载"}
    try:
        if not broadcasts:
            return {"error": "广播列表不能为空"}
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, _segment_batch_sync, broadcasts)
    except Exception as e:
        return {"error": f"双路径批量处理失败: {str(e)}"}
def _segment_with_details_sync(text):
    """Segment ``text`` pair by pair and collect per-pair probabilities and boundary scores.

    Blocking: it runs the model once per adjacent sentence pair, each padded to 384
    tokens. It is meant to be called off the event loop. Returns the endpoint's
    success payload.
    """
    sentences = split_text_into_sentences(text)
    if len(sentences) < 2:
        return {
            "segments": {"paragraph_1": text},
            "sentence_details": [],
            "total_sentences": len(sentences)
        }
    results = {}
    current_paragraph_sentences = [sentences[0]]
    sentence_details = []
    with torch.no_grad():
        for i, (first, second) in enumerate(zip(sentences, sentences[1:])):
            encoding = tokenizer(
                re.sub(r'[。!?;]$', '', first),
                re.sub(r'[。!?;]$', '', second),
                truncation=True,
                padding='max_length',
                max_length=384,
                return_tensors='pt'
            )
            # NOTE(review): token_type_ids are not forwarded, so the backbone uses
            # its default segment ids. Confirm this matches training.
            outputs = model(
                input_ids=encoding['input_ids'].to(device),
                attention_mask=encoding['attention_mask'].to(device)
            )
            logits = outputs['logits']
            boundary_score = outputs['boundary_score'].item()
            prediction = torch.argmax(logits, dim=-1).item()
            probability = torch.softmax(logits, dim=-1)
            regular_prob = torch.softmax(outputs['regular_logits'], dim=-1)
            boundary_prob = torch.softmax(outputs['boundary_logits'], dim=-1)
            sentence_details.append({
                "sentence_pair_index": i + 1,
                "sentence1": first,
                "sentence2": second,
                "prediction": prediction,
                "prediction_label": "same_paragraph" if prediction == 0 else "different_paragraph",
                "final_probabilities": {
                    "same_paragraph": float(probability[0][0]),
                    "different_paragraph": float(probability[0][1])
                },
                "regular_path_probabilities": {
                    "same_paragraph": float(regular_prob[0][0]),
                    "different_paragraph": float(regular_prob[0][1])
                },
                "boundary_path_probabilities": {
                    "same_paragraph": float(boundary_prob[0][0]),
                    "different_paragraph": float(boundary_prob[0][1])
                },
                "boundary_score": boundary_score,
                "boundary_confidence": "high" if boundary_score > 0.7 else "medium" if boundary_score > 0.3 else "low"
            })
            if prediction == 0:  # same paragraph
                current_paragraph_sentences.append(second)
            else:  # different paragraph
                results[f"paragraph_{len(results) + 1}"] = "".join(current_paragraph_sentences)
                current_paragraph_sentences = [second]
    results[f"paragraph_{len(results) + 1}"] = "".join(current_paragraph_sentences)
    boundary_scores = [d["boundary_score"] for d in sentence_details]
    return {
        "segments": results,
        "sentence_details": sentence_details,
        "total_sentences": len(sentences),
        "total_pairs_analyzed": len(sentence_details),
        "boundary_statistics": {
            "average_boundary_score": round(sum(boundary_scores) / len(boundary_scores),
                                            4) if boundary_scores else 0,
            "max_boundary_score": round(max(boundary_scores), 4) if boundary_scores else 0,
            "min_boundary_score": round(min(boundary_scores), 4) if boundary_scores else 0
        }
    }


@app.post("/segment_with_details")
async def segment_with_details(input_data: TextInput):
    """带详细信息的文本分段接口"""
    # Detailed segmentation endpoint. The per-pair inference loop is blocking, so
    # it runs in the loop's default executor instead of stalling the event loop.
    if model is None or tokenizer is None:
        return {"error": "双路径模型未加载"}
    try:
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, _segment_with_details_sync, input_data.广播内容)
    except Exception as e:
        return {"error": f"详细分析失败: {str(e)}"}
@app.get("/stats")
async def get_processing_stats():
    """获取处理统计信息"""
    # Service metadata plus the live model/device/tokenizer state.
    loaded = model is not None
    endpoint_list = [
        "/segment_simple - 单条处理",
        "/segment_batch_simple - 批量处理(简化)",
        "/segment_with_details - 带详细信息分段",
    ]
    return {
        "service_status": "running" if loaded else "down",
        "model_loaded": loaded,
        "model_type": "DualPathBoundaryClassifier",
        "device": str(device) if device else None,
        "vocab_size": tokenizer.vocab_size if tokenizer else None,
        "max_sequence_length": 384,
        "api_endpoints": endpoint_list,
    }
# ========== 启动配置 ==========
if __name__ == "__main__":
    # Defaults match the original deployment (0.0.0.0:8888). Override with
    # SEGMENT_SERVICE_HOST / SEGMENT_SERVICE_PORT for other environments.
    uvicorn.run(
        app,
        host=os.environ.get("SEGMENT_SERVICE_HOST", "0.0.0.0"),
        port=int(os.environ.get("SEGMENT_SERVICE_PORT", "8888")),
    )