|
|
import json |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import numpy as np |
|
|
import math |
|
|
import os |
|
|
import re |
|
|
import time |
|
|
import concurrent.futures |
|
|
from typing import Dict, List, Optional, Union |
|
|
from fastapi import FastAPI, HTTPException |
|
|
from pydantic import BaseModel, Field |
|
|
from transformers import BertTokenizer, BertModel |
|
|
import uvicorn |
|
|
|
|
|
|
|
|
# ========== 双路径边界分类器模型定义部分 ========== |
|
|
class ScaledDotProductAttention(nn.Module):
    """Parameter-free scaled dot-product attention with dropout on the weights.

    Computes ``softmax(Q @ K^T / sqrt(d)) @ V`` where ``d`` is the size of the
    last dimension of ``query``.
    """

    def __init__(self, d_model, dropout=0.1):
        """
        Args:
            d_model: Nominal feature size. Stored for reference only; the
                scaling factor is derived from ``query`` at call time.
            dropout: Dropout probability applied to the attention weights.
        """
        super().__init__()
        self.d_model = d_model
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        """Attend ``query`` over ``key``/``value``.

        Args:
            query, key, value: Tensors whose last two dimensions are
                (sequence, feature). Any number of leading dimensions is
                accepted as long as they broadcast under ``torch.matmul``.
            mask: Optional tensor broadcastable to the score matrix; positions
                where ``mask == 0`` are excluded from attention.

        Returns:
            Tuple ``(output, attention_weights)``.
        """
        # Read only the feature size so inputs are not restricted to 3-D.
        scale = math.sqrt(query.size(-1))
        scores = torch.matmul(query, key.transpose(-2, -1)) / scale

        if mask is not None:
            # Most negative finite value of the dtype: masked positions get
            # (numerically) zero weight without introducing -inf.
            mask_value = torch.finfo(scores.dtype).min
            scores = scores.masked_fill(mask == 0, mask_value)

        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        output = torch.matmul(attention_weights, value)

        return output, attention_weights
|
|
|
|
|
|
|
|
class FocalLoss(nn.Module):
    """Multi-class focal loss built on top of cross-entropy.

    Per-sample loss is ``alpha[target] * (1 - p_t) ** gamma * CE`` where
    ``p_t = exp(-CE)`` is the softmax probability of the target class.
    """

    def __init__(self, alpha=None, gamma=3.0, reduction='mean'):
        """
        Args:
            alpha: Optional per-class weights (1-D tensor or list/tuple of
                floats, indexed by class id). ``None`` disables weighting.
            gamma: Focusing exponent applied to ``(1 - p_t)``.
            reduction: ``'mean'``, ``'sum'``; any other value returns the
                unreduced per-sample loss.
        """
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Compute the focal loss.

        Args:
            inputs: Raw class logits as accepted by ``F.cross_entropy``.
            targets: Integer class indices as accepted by ``F.cross_entropy``.

        Returns:
            Scalar tensor for ``'mean'``/``'sum'``, otherwise a tensor shaped
            like the per-sample cross-entropy.
        """
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        # p_t is taken from the *unweighted* cross-entropy, before alpha.
        pt = torch.exp(-ce_loss)

        if self.alpha is not None:
            # Convert per call instead of overwriting self.alpha: the stored
            # weights keep their original dtype/device and forward() stays
            # free of side effects.
            alpha = torch.as_tensor(self.alpha, dtype=inputs.dtype, device=inputs.device)
            # Reshape back so targets with extra dimensions still line up
            # with ce_loss.
            at = alpha.gather(0, targets.reshape(-1)).view_as(ce_loss)
            ce_loss = ce_loss * at

        focal_weight = (1 - pt) ** self.gamma
        focal_loss = focal_weight * ce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss
|
|
|
|
|
|
|
|
class DualPathBoundaryClassifier(nn.Module):
    """BERT-based classifier with two classification heads and a boundary detector.

    The encoder output is passed through a scaled dot-product self-attention
    layer; the first-position vector then feeds a "regular" head, a
    "boundary" head and a 1-unit boundary detector. The final logits are the
    regular logits with ``sigmoid(detector) * boundary_force_weight`` added to
    class index 1.
    """

    def __init__(self, model_path, num_labels=2, dropout=0.1,
                 focal_alpha=None, focal_gamma=3.0, boundary_force_weight=2.0):
        """
        Args:
            model_path: Path or identifier handed to ``BertModel.from_pretrained``.
            num_labels: Number of output classes (class index 1 receives the
                boundary bias, so at least 2 are required).
            dropout: Dropout probability for the attention layer and the
                pooled vector.
            focal_alpha: Optional per-class weights forwarded to ``FocalLoss``.
            focal_gamma: Focusing exponent forwarded to ``FocalLoss``.
            boundary_force_weight: Initial value of the learnable scale
                applied to the boundary score.
        """
        super().__init__()

        # NOTE: the attribute is called "roberta" but the class loaded is
        # BertModel; the name is kept because it is part of the state_dict keys.
        self.roberta = BertModel.from_pretrained(model_path)
        self.config = self.roberta.config
        self.config.num_labels = num_labels

        self.scaled_attention = ScaledDotProductAttention(
            d_model=self.config.hidden_size,
            dropout=dropout,
        )

        self.dropout = nn.Dropout(dropout)

        # Dual-path classification heads plus the boundary detector.
        self.regular_classifier = nn.Linear(self.config.hidden_size, num_labels)
        self.boundary_classifier = nn.Linear(self.config.hidden_size, num_labels)
        self.boundary_detector = nn.Linear(self.config.hidden_size, 1)

        # Learnable boundary scale. float() is required: an integer value
        # (e.g. ``2`` read from a JSON config) would create an int64 tensor,
        # which nn.Parameter rejects because it cannot require gradients.
        self.boundary_force_weight = nn.Parameter(torch.tensor(float(boundary_force_weight)))

        self.focal_loss = FocalLoss(alpha=focal_alpha, gamma=focal_gamma)

        self._init_weights()

        self.focal_alpha = focal_alpha
        self.focal_gamma = focal_gamma

    def _init_weights(self):
        """Initialise the three heads: N(0, 0.02) weights, zero biases."""
        for head in (self.regular_classifier, self.boundary_classifier, self.boundary_detector):
            nn.init.normal_(head.weight, std=0.02)
            nn.init.zeros_(head.bias)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, labels=None):
        """Run the encoder and all heads; optionally compute the training loss.

        Args:
            input_ids: Token ids for the encoder (assumed (batch, seq_len) —
                TODO confirm against the tokenizer output).
            attention_mask: Optional padding mask; also used to mask keys in
                the extra attention layer.
            token_type_ids: Optional segment ids forwarded to the encoder.
            labels: Optional class indices; when given, ``loss`` is computed.

        Returns:
            Dict with keys ``loss`` (``None`` without labels), ``logits``,
            ``regular_logits``, ``boundary_logits``, ``boundary_score``,
            ``hidden_states`` and ``attention_weights``.
        """
        roberta_outputs = self.roberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=True,
        )

        sequence_output = roberta_outputs.last_hidden_state

        # Self-attention over the encoder output; the mask gets a singleton
        # query dimension so it broadcasts over every query position.
        enhanced_output, attention_weights = self.scaled_attention(
            query=sequence_output,
            key=sequence_output,
            value=sequence_output,
            mask=attention_mask.unsqueeze(1) if attention_mask is not None else None,
        )

        # First-position vector is used as the pooled representation.
        cls_output = enhanced_output[:, 0, :]
        cls_output = self.dropout(cls_output)

        # Dual-path classification.
        regular_logits = self.regular_classifier(cls_output)
        boundary_logits = self.boundary_classifier(cls_output)

        # Boundary detection.
        boundary_logits_raw = self.boundary_detector(cls_output).squeeze(-1)
        boundary_score = torch.sigmoid(boundary_logits_raw)

        # Fusion: only class index 1 is biased by the scaled boundary score.
        boundary_bias = torch.zeros_like(regular_logits)
        boundary_bias[:, 1] = boundary_score * self.boundary_force_weight

        final_logits = regular_logits + boundary_bias

        loss = None
        if labels is not None:
            regular_loss = self.focal_loss(regular_logits, labels)
            boundary_loss = self.focal_loss(boundary_logits, labels)
            final_loss = self.focal_loss(final_logits, labels)

            boundary_labels = self._generate_boundary_labels(labels)
            detection_loss = F.binary_cross_entropy_with_logits(boundary_logits_raw, boundary_labels)

            # Fixed weighting of the four loss terms.
            loss = (0.4 * final_loss +
                    0.3 * regular_loss +
                    0.2 * boundary_loss +
                    0.1 * detection_loss)

        return {
            'loss': loss,
            'logits': final_logits,
            'regular_logits': regular_logits,
            'boundary_logits': boundary_logits,
            'boundary_score': boundary_score,
            'hidden_states': enhanced_output,
            'attention_weights': attention_weights,
        }

    def _generate_boundary_labels(self, labels):
        """Turn class labels into soft detector targets.

        Adds uniform noise in [0, 0.1) to the float labels and clamps to
        [0, 1]; label 1 therefore stays exactly 1, label 0 lands in [0, 0.1).
        """
        boundary_labels = labels.float()
        noise = torch.rand_like(boundary_labels) * 0.1
        boundary_labels = torch.clamp(boundary_labels + noise, 0.0, 1.0)
        return boundary_labels
|
|
|
|
|
|
|
|
# ========== 模型加载部分 ========== |
|
|
def check_gpu_availability():
    """Choose the compute device and print which one was selected.

    Returns:
        ``torch.device('cuda')`` when CUDA is available, else
        ``torch.device('cpu')``.
    """
    if not torch.cuda.is_available():
        print("🔄 使用CPU")
        return torch.device('cpu')

    # Report name and total memory (in GiB) of device 0.
    gpu_name = torch.cuda.get_device_name(0)
    total_bytes = torch.cuda.get_device_properties(0).total_memory
    gpu_memory = total_bytes / 1024 ** 3
    print(f"🚀 GPU: {gpu_name} ({gpu_memory:.1f} GB)")
    return torch.device('cuda')
|
|
|
|
|
|
|
|
def safe_convert_focal_alpha(focal_alpha, device):
    """Best-effort conversion of a focal-loss alpha value to a tensor on ``device``.

    Args:
        focal_alpha: ``None``, a tensor, a NumPy array, a Python number, or
            anything ``torch.tensor`` accepts (list, tuple, ...).
        device: Target device for the resulting tensor.

    Returns:
        The tensor on ``device``, or ``None`` when the input is ``None`` or
        the conversion fails (the failure is printed, not raised).
    """
    if focal_alpha is None:
        return None

    try:
        if isinstance(focal_alpha, torch.Tensor):
            # Existing tensors keep their dtype; only the device changes.
            converted = focal_alpha
        elif isinstance(focal_alpha, np.ndarray):
            converted = torch.from_numpy(focal_alpha).float()
        elif isinstance(focal_alpha, (int, float)):
            # NOTE(review): a scalar becomes a 1-element tensor — confirm this
            # matches how the loss indexes alpha by class id.
            converted = torch.tensor([focal_alpha], dtype=torch.float32)
        else:
            # Lists, tuples and any other input torch.tensor can parse.
            converted = torch.tensor(focal_alpha, dtype=torch.float32)
        return converted.to(device)
    except Exception as e:
        print(f"⚠️ 转换focal_alpha失败: {e}")
        return None
|
|
|
|
|
|
|
|
def load_trained_dual_path_model(model_path, device, original_roberta_path=None):
    """Build the classifier, load fine-tuned weights and return it with its tokenizer.

    Args:
        model_path: Directory holding ``pytorch_model.bin`` and, optionally,
            a ``config.json`` with ``focal_gamma`` / ``focal_alpha`` /
            ``boundary_force_weight`` entries.
        device: Device the model is moved to.
        original_roberta_path: Optional directory of the base checkpoint.
            When it exists, the tokenizer and encoder are initialised from
            it; otherwise ``model_path`` is used for both.

    Returns:
        ``(model, tokenizer)`` with the model in eval mode on ``device``.

    Raises:
        ValueError: Encoder embedding size and tokenizer vocabulary differ.
        FileNotFoundError: ``pytorch_model.bin`` is missing.
        Exception: Any other loading error is printed and re-raised.
    """
    try:
        # One source directory serves both the tokenizer and the encoder init.
        use_base = bool(original_roberta_path) and os.path.exists(original_roberta_path)
        init_path = original_roberta_path if use_base else model_path

        tokenizer = BertTokenizer.from_pretrained(init_path)

        # Training-time hyper-parameters, falling back to defaults when the
        # config file or an individual key is absent.
        hyper_params = {}
        config_path = os.path.join(model_path, 'config.json')
        if os.path.exists(config_path):
            with open(config_path, 'r', encoding='utf-8') as f:
                hyper_params = json.load(f)
        focal_gamma = hyper_params.get('focal_gamma', 3.0)
        focal_alpha_raw = hyper_params.get('focal_alpha', None)
        boundary_force_weight = hyper_params.get('boundary_force_weight', 2.0)

        focal_alpha = safe_convert_focal_alpha(focal_alpha_raw, device)

        model = DualPathBoundaryClassifier(
            model_path=init_path,
            num_labels=2,
            dropout=0.1,
            focal_alpha=focal_alpha,
            focal_gamma=focal_gamma,
            boundary_force_weight=boundary_force_weight,
        )

        # Fail early if the token ids would not fit the embedding table.
        model_vocab_size = model.roberta.embeddings.word_embeddings.weight.shape[0]
        if model_vocab_size != tokenizer.vocab_size:
            raise ValueError(f"词汇表大小不匹配: 模型={model_vocab_size}, Tokenizer={tokenizer.vocab_size}")

        model_weights_path = os.path.join(model_path, 'pytorch_model.bin')
        if not os.path.exists(model_weights_path):
            raise FileNotFoundError(f"未找到模型权重文件: {model_weights_path}")

        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted location.
        try:
            state_dict = torch.load(model_weights_path, map_location=device)
        except Exception:
            # Retry on CPU if mapping straight to the target device fails.
            state_dict = torch.load(model_weights_path, map_location='cpu')

        # Non-strict load: key mismatches are reported, not fatal.
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
        if missing_keys or unexpected_keys:
            print(f"⚠️ 加载权重时有差异: 缺少{len(missing_keys)}个键, 多余{len(unexpected_keys)}个键")

        model.to(device)
        model.eval()

        print("✅ 双路径边界分类器模型加载完成")
        return model, tokenizer

    except Exception as e:
        print(f"❌ 双路径模型加载失败: {str(e)}")
        raise
|
|
|
|
|
|
|
|
# ========== 文本处理部分 ========== |
|
|
def split_text_into_sentences(text: str) -> List[str]:
    """Split ``text`` into sentences, keeping each terminator with its sentence.

    A sentence ends at any of the characters in the pattern below.
    Whitespace around each sentence body is stripped. A terminator with no
    text before it (e.g. the second mark in ``?!``) is appended to the
    previous sentence instead of being dropped; terminators that precede the
    very first sentence have nothing to attach to and are discarded.

    Args:
        text: Raw input text.

    Returns:
        Sentences in order; an empty list for empty or whitespace-only input.
    """
    # The capturing group makes re.split return
    # [body, terminator, body, terminator, ..., tail].
    parts = re.split(r'([。!?;])', text.strip())

    sentences: List[str] = []
    for idx in range(0, len(parts), 2):
        body = parts[idx].strip()
        terminator = parts[idx + 1] if idx + 1 < len(parts) else ''
        if body:
            sentences.append(body + terminator)
        elif terminator and sentences:
            # Orphaned terminator: keep it on the preceding sentence so no
            # punctuation is lost from the reassembled text.
            sentences[-1] += terminator

    return sentences
|
|
|
|
|
|
|
|
def predict_sentence_pairs_dual_path(sentences: List[str], model, tokenizer, device, max_length=384) -> Dict[str, str]:
    """Group consecutive sentences into paragraphs using pairwise predictions.

    Each adjacent sentence pair is classified by ``model``; class 0 keeps the
    second sentence in the current paragraph, any other class starts a new
    paragraph.

    Args:
        sentences: Sentences in document order.
        model: Callable returning a mapping with a ``'logits'`` tensor for a
            single encoded pair.
        tokenizer: Callable encoding a sentence pair into ``input_ids`` and
            ``attention_mask`` tensors.
        device: Device the encoded tensors are moved to.
        max_length: Padded/truncated token length per pair.

    Returns:
        ``{"paragraph_1": ..., "paragraph_2": ...}`` in order. With fewer than
        two sentences, the single sentence (or ``""``) is ``paragraph_1``.
    """
    if len(sentences) < 2:
        return {"paragraph_1": sentences[0] if sentences else ""}

    # One trailing terminator is removed before the pair is encoded.
    trailing_mark = re.compile(r'[。!?;]$')

    paragraphs: List[str] = []
    open_paragraph = [sentences[0]]

    with torch.no_grad():
        for left, right in zip(sentences, sentences[1:]):
            encoded = tokenizer(
                trailing_mark.sub('', left),
                trailing_mark.sub('', right),
                truncation=True,
                padding='max_length',
                max_length=max_length,
                return_tensors='pt',
            )

            # NOTE(review): token_type_ids produced by the tokenizer are not
            # forwarded — confirm this matches how the model was trained.
            outputs = model(
                input_ids=encoded['input_ids'].to(device),
                attention_mask=encoded['attention_mask'].to(device),
            )

            predicted_class = torch.argmax(outputs['logits'], dim=-1).item()
            if predicted_class == 0:
                # Same paragraph.
                open_paragraph.append(right)
            else:
                # Paragraph break: close the current one, start the next.
                paragraphs.append("".join(open_paragraph))
                open_paragraph = [right]

    paragraphs.append("".join(open_paragraph))
    return {f"paragraph_{number}": body for number, body in enumerate(paragraphs, start=1)}
|
|
|
|
|
|
|
|
def process_single_broadcast_dual_path(text: str, broadcast_id: Optional[str] = None) -> dict:
    """Segment one text and wrap the outcome in a status dictionary.

    Uses the module-level ``model``, ``tokenizer`` and ``device``. Never
    raises: any exception is converted into a ``"failed"`` result.

    Args:
        text: Text to segment.
        broadcast_id: Optional identifier echoed back in the result.

    Returns:
        ``{"broadcast_id", "segments", "status"}`` plus ``"error"`` when
        ``status`` is ``"failed"``.
    """
    def _failed(message):
        return {
            "broadcast_id": broadcast_id,
            "segments": {},
            "status": "failed",
            "error": message,
        }

    try:
        if not (text and text.strip()):
            return _failed("文本为空")

        sentences = split_text_into_sentences(text)

        if sentences:
            segments = predict_sentence_pairs_dual_path(sentences, model, tokenizer, device)
        else:
            # Nothing survived splitting: return the text as one paragraph.
            segments = {"paragraph_1": text}

        return {
            "broadcast_id": broadcast_id,
            "segments": segments,
            "status": "success",
        }

    except Exception as e:
        return _failed(str(e))
|
|
|
|
|
|
|
|
# ========== FastAPI应用部分 ========== |
|
|
app = FastAPI(title="双路径边界分类器文本分段服务", version="3.0.0")

# Module-level holders for the model, tokenizer and device. They start as
# None, are assigned by the startup hook, and are read directly by the
# request handlers and processing helpers.
model = None
tokenizer = None
device = None
|
|
|
|
|
|
|
|
# ========== 请求和响应模型 ========== |
|
|
class TextInput(BaseModel):
    # Request body of the single-text endpoints.
    广播内容: str  # the text to segment
|
|
|
|
|
|
|
|
class BroadcastItem(BaseModel):
    # One entry of a batch request.
    广播内容: str  # the text to segment
    广播ID: Optional[str] = None  # optional caller-supplied id; used as the result key by the batch endpoint
|
|
|
|
|
|
|
|
# Request body type of the batch endpoint: a list of BroadcastItem objects.
BatchInput = List[BroadcastItem]
|
|
|
|
|
|
|
|
# ========== 生命周期事件 ========== |
|
|
@app.on_event("startup")
async def load_model():
    """Load the model and tokenizer once when the application starts.

    Assigns the module-level ``model``, ``tokenizer`` and ``device``. A
    loading failure is printed and re-raised.
    """
    global model, tokenizer, device

    # Fine-tuned checkpoint directory and the base checkpoint it was built on.
    finetuned_dir = "/work/model_robert/model_train-eval"
    base_dir = "/work/model_robert/model"

    print("🚀 正在启动双路径边界分类器文本分段服务...")

    try:
        device = check_gpu_availability()
        loaded = load_trained_dual_path_model(finetuned_dir, device, base_dir)
        model, tokenizer = loaded
        print(f"✅ 双路径边界分类器模型加载成功! 设备: {device}")
    except Exception as e:
        print(f"❌ 双路径模型加载失败: {e}")
        raise
|
|
|
|
|
|
|
|
# ========== API接口 ========== |
|
|
@app.get("/")
async def root():
    """Root path: report service status and the advertised feature list."""
    status_text = "运行中" if model is not None else "模型未加载"
    feature_list = [
        "双路径架构: 常规分类器 + 边界分类器",
        "边界检测器: 纯神经网络学习边界模式",
        "动态权重融合: 自适应边界识别",
        "序列长度: 384 tokens",
        "单条处理",
        "批量处理(简化)",
        "详细分析",
    ]
    return {
        "service": "双路径边界分类器文本分段服务",
        "status": status_text,
        "version": "3.0.0",
        "model_type": "DualPathBoundaryClassifier",
        "features": feature_list,
    }
|
|
|
|
|
|
|
|
@app.get("/health")
async def health_check():
    """Health check: report whether the model is loaded and on which device."""
    is_loaded = model is not None
    device_name = str(device) if device else None
    return {
        "status": "healthy" if is_loaded else "unhealthy",
        "model_loaded": is_loaded,
        "model_type": "DualPathBoundaryClassifier",
        "device": device_name,
        "max_sequence_length": 384,
        "boundary_detection": "pure_neural_network",
    }
|
|
|
|
|
|
|
|
@app.post("/segment_simple")
async def segment_text_simple(input_data: TextInput):
    """Segment a single text.

    Returns the paragraph mapping on success, otherwise ``{"error": ...}``.
    """
    if model is None or tokenizer is None:
        return {"error": "双路径模型未加载"}

    try:
        outcome = process_single_broadcast_dual_path(input_data.广播内容)
        if outcome["status"] != "success":
            return {"error": outcome["error"]}
        return outcome["segments"]
    except Exception as e:
        return {"error": f"双路径处理失败: {str(e)}"}
|
|
|
|
|
|
|
|
@app.post("/segment_batch_simple")
async def segment_batch_simple(broadcasts: BatchInput):
    """Segment a batch of texts and return a simplified summary.

    Items are processed on a 4-worker thread pool. Each result is keyed by
    the item's ``广播ID`` or, when absent, by ``broadcast_<position>`` where
    the position is the item's 1-based index in the request.

    Returns:
        Summary dict with ``model``, ``total``, ``success``, ``failed``,
        ``processing_time`` (seconds) and ``results``; or ``{"error": ...}``
        when the model is not loaded, the list is empty, or processing fails.
    """
    if model is None or tokenizer is None:
        return {"error": "双路径模型未加载"}

    try:
        if not broadcasts:
            return {"error": "广播列表不能为空"}

        start_time = time.time()
        results = []

        # Process in parallel.
        with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
            futures = [
                executor.submit(process_single_broadcast_dual_path,
                                broadcast.广播内容,
                                broadcast.广播ID)
                for broadcast in broadcasts
            ]

            # Collect in submission order (not completion order) so that the
            # positional fallback keys below match the request order.
            for broadcast, future in zip(broadcasts, futures):
                try:
                    results.append(future.result())
                except Exception as e:
                    results.append({
                        "broadcast_id": broadcast.广播ID,
                        "status": "failed",
                        "error": f"双路径处理异常: {str(e)}"
                    })

        # Simplified output format.
        # NOTE: items sharing the same key overwrite each other in this
        # mapping, while "total" still counts every item.
        simplified_results = {}
        success_count = 0

        for i, result in enumerate(results):
            key = result.get("broadcast_id") or f"broadcast_{i + 1}"

            if result["status"] == "success":
                simplified_results[key] = result["segments"]
                success_count += 1
            else:
                simplified_results[key] = {"error": result.get("error", "双路径处理失败")}

        total_time = time.time() - start_time

        return {
            "model": "DualPathBoundaryClassifier",
            "total": len(results),
            "success": success_count,
            "failed": len(results) - success_count,
            "processing_time": round(total_time, 3),
            "results": simplified_results
        }

    except Exception as e:
        return {"error": f"双路径批量处理失败: {str(e)}"}
|
|
|
|
|
|
|
|
@app.post("/segment_with_details")
async def segment_with_details(input_data: TextInput):
    """Segment a text and report per-pair model outputs.

    Returns the paragraph mapping together with, for every adjacent sentence
    pair, the predicted class, the softmax probabilities of the final,
    regular and boundary logits, and the boundary score; or
    ``{"error": ...}`` on failure.
    """
    if model is None or tokenizer is None:
        return {"error": "双路径模型未加载"}

    def _two_class_probs(pair_logits):
        # Softmax over the class axis of a single-pair logits tensor.
        probs = torch.softmax(pair_logits, dim=-1)
        return {
            "same_paragraph": float(probs[0][0]),
            "different_paragraph": float(probs[0][1]),
        }

    try:
        text = input_data.广播内容
        sentences = split_text_into_sentences(text)

        # Nothing to compare: return the raw text as a single paragraph.
        if len(sentences) < 2:
            return {
                "segments": {"paragraph_1": text},
                "sentence_details": [],
                "total_sentences": len(sentences),
            }

        paragraphs = []
        open_paragraph = [sentences[0]]
        sentence_details = []

        with torch.no_grad():
            for pair_index, (left, right) in enumerate(zip(sentences, sentences[1:]), start=1):
                # One trailing terminator is removed before encoding.
                encoded = tokenizer(
                    re.sub(r'[。!?;]$', '', left),
                    re.sub(r'[。!?;]$', '', right),
                    truncation=True,
                    padding='max_length',
                    max_length=384,
                    return_tensors='pt',
                )

                outputs = model(
                    input_ids=encoded['input_ids'].to(device),
                    attention_mask=encoded['attention_mask'].to(device),
                )

                final_logits = outputs['logits']
                boundary_score = outputs['boundary_score'].item()
                prediction = torch.argmax(final_logits, dim=-1).item()

                if boundary_score > 0.7:
                    confidence = "high"
                elif boundary_score > 0.3:
                    confidence = "medium"
                else:
                    confidence = "low"

                sentence_details.append({
                    "sentence_pair_index": pair_index,
                    "sentence1": left,
                    "sentence2": right,
                    "prediction": prediction,
                    "prediction_label": "same_paragraph" if prediction == 0 else "different_paragraph",
                    "final_probabilities": _two_class_probs(final_logits),
                    "regular_path_probabilities": _two_class_probs(outputs['regular_logits']),
                    "boundary_path_probabilities": _two_class_probs(outputs['boundary_logits']),
                    "boundary_score": boundary_score,
                    "boundary_confidence": confidence,
                })

                if prediction == 0:
                    # Same paragraph.
                    open_paragraph.append(right)
                else:
                    # Paragraph break.
                    paragraphs.append("".join(open_paragraph))
                    open_paragraph = [right]

        paragraphs.append("".join(open_paragraph))
        segments = {f"paragraph_{number}": body for number, body in enumerate(paragraphs, start=1)}

        boundary_scores = [detail["boundary_score"] for detail in sentence_details]
        if boundary_scores:
            boundary_statistics = {
                "average_boundary_score": round(sum(boundary_scores) / len(boundary_scores), 4),
                "max_boundary_score": round(max(boundary_scores), 4),
                "min_boundary_score": round(min(boundary_scores), 4),
            }
        else:
            boundary_statistics = {
                "average_boundary_score": 0,
                "max_boundary_score": 0,
                "min_boundary_score": 0,
            }

        return {
            "segments": segments,
            "sentence_details": sentence_details,
            "total_sentences": len(sentences),
            "total_pairs_analyzed": len(sentence_details),
            "boundary_statistics": boundary_statistics,
        }

    except Exception as e:
        return {"error": f"详细分析失败: {str(e)}"}
|
|
|
|
|
|
|
|
@app.get("/stats")
async def get_processing_stats():
    """Return service information: load state, device, vocabulary size, endpoints."""
    is_loaded = model is not None
    endpoint_list = [
        "/segment_simple - 单条处理",
        "/segment_batch_simple - 批量处理(简化)",
        "/segment_with_details - 带详细信息分段",
    ]
    return {
        "service_status": "running" if is_loaded else "down",
        "model_loaded": is_loaded,
        "model_type": "DualPathBoundaryClassifier",
        "device": str(device) if device else None,
        "vocab_size": tokenizer.vocab_size if tokenizer else None,
        "max_sequence_length": 384,
        "api_endpoints": endpoint_list,
    }
|
|
|
|
|
|
|
|
# ========== 启动配置 ========== |
|
|
if __name__ == "__main__":
    # Serve the app directly when the file is run as a script; binds to all
    # interfaces on port 8888.
    uvicorn.run(app, host='0.0.0.0', port=8888)