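"""FastAPI service for Chinese paragraph segmentation.

Wraps a dual-path boundary classifier (a chinese-roberta-wwm-ext backbone plus a
scaled dot-product attention layer and a boundary-detection head) behind a small
web UI and a JSON API: adjacent sentences are scored pairwise, and paragraph
breaks are inserted where the model predicts a boundary.
"""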
import torch
import torch.nn as nn
import torch.nn.functional as F
import json
import os
import math
import logging
import re
from datetime import datetime
from typing import List, Dict, Any
from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
from transformers import BertTokenizer, BertModel
import uvicorn
# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Globals populated at startup
model = None
tokenizer = None
device = None
class ScaledDotProductAttention(nn.Module):
"""缩放点积注意力机制"""
def __init__(self, d_model, dropout=0.1):
super(ScaledDotProductAttention, self).__init__()
self.d_model = d_model
self.dropout = nn.Dropout(dropout)
def forward(self, query, key, value, mask=None):
batch_size, seq_len, d_model = query.size()
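        # Attention(Q, K, V) = softmax(Q·Kᵀ / √d) · V; the scale uses the full
        # hidden size d_model because this layer is single-head (no head split).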
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_model)
if mask is not None:
mask_value = torch.finfo(scores.dtype).min
scores = scores.masked_fill(mask == 0, mask_value)
attention_weights = F.softmax(scores, dim=-1)
attention_weights = self.dropout(attention_weights)
output = torch.matmul(attention_weights, value)
return output, attention_weights
class DualPathBoundaryClassifier(nn.Module):
"""双路径边界分类器,完全依靠神经网络学习边界模式"""
def __init__(self, model_path, num_labels=2, dropout=0.1, boundary_force_weight=2.0):
super(DualPathBoundaryClassifier, self).__init__()
        # chinese-roberta-wwm-ext is distributed with BERT-style weights and
        # configs, so BertModel/BertTokenizer are the correct classes despite the name.
        self.roberta = BertModel.from_pretrained(model_path)
self.config = self.roberta.config
self.config.num_labels = num_labels
self.scaled_attention = ScaledDotProductAttention(
d_model=self.config.hidden_size,
dropout=dropout
)
self.dropout = nn.Dropout(dropout)
        # Dual-path classification heads
self.regular_classifier = nn.Linear(self.config.hidden_size, num_labels)
self.boundary_classifier = nn.Linear(self.config.hidden_size, num_labels)
self.boundary_detector = nn.Linear(self.config.hidden_size, 1)
        # Learnable boundary forcing weight
self.boundary_force_weight = nn.Parameter(torch.tensor(boundary_force_weight))
def forward(self, input_ids, attention_mask=None, token_type_ids=None):
roberta_outputs = self.roberta(
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
return_dict=True
)
sequence_output = roberta_outputs.last_hidden_state
        # Enhance the sequence with scaled dot-product self-attention
enhanced_output, attention_weights = self.scaled_attention(
query=sequence_output,
key=sequence_output,
value=sequence_output,
mask=attention_mask.unsqueeze(1) if attention_mask is not None else None
)
cls_output = enhanced_output[:, 0, :]
cls_output = self.dropout(cls_output)
        # Dual-path classification
regular_logits = self.regular_classifier(cls_output)
boundary_logits = self.boundary_classifier(cls_output)
        # Boundary detection head
boundary_logits_raw = self.boundary_detector(cls_output).squeeze(-1)
boundary_score = torch.sigmoid(boundary_logits_raw)
        # Dynamic fusion: bias the "split" logit by the boundary score
boundary_bias = torch.zeros_like(regular_logits)
boundary_bias[:, 1] = boundary_score * self.boundary_force_weight
final_logits = regular_logits + boundary_bias
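        # The learnable boundary_force_weight controls how strongly the boundary
        # detector can push a pair toward the "split" class (label 1). Note that
        # boundary_logits is returned for inspection/auxiliary use but does not
        # enter the fused final_logits.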
return {
'logits': final_logits,
'regular_logits': regular_logits,
'boundary_logits': boundary_logits,
'boundary_score': boundary_score,
'hidden_states': enhanced_output,
'attention_weights': attention_weights
}
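# Minimal usage sketch (assumes a local chinese-roberta-wwm-ext checkpoint and a
# matching BertTokenizer `tok`; illustrative only, not executed):
#   clf = DualPathBoundaryClassifier(model_path="path/to/chinese-roberta-wwm-ext")
#   enc = tok("句子一。", "句子二。", return_tensors="pt")
#   out = clf(**enc)
#   out['boundary_score']  # sigmoid output in [0, 1]; higher = stronger boundary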
def load_model():
"""加载训练好的模型"""
global model, tokenizer, device
    # Check for a GPU
if torch.cuda.is_available():
device = torch.device('cuda')
gpu_name = torch.cuda.get_device_name(0)
logger.info(f"✅ 使用GPU: {gpu_name}")
else:
device = torch.device('cpu')
logger.info(" 使用CPU运行")
# 模型路径配置
model_path = r"D:\workstation\chinese-roberta-wwm-ext\model-train-eval-NN\model_train"
original_model_path = r"D:\workstation\chinese-roberta-wwm-ext\model"
logger.info(f"📥 加载双路径边界分类器模型...")
try:
        # Check the trained-model directory for a config file first
config_path = os.path.join(model_path, 'config.json')
if os.path.exists(config_path):
with open(config_path, 'r', encoding='utf-8') as f:
model_config = json.load(f)
boundary_force_weight = model_config.get('boundary_force_weight', 2.0)
logger.info(f" 🔹 边界强制权重: {boundary_force_weight}")
else:
boundary_force_weight = 2.0
        # Try the tokenizer from the training directory first; fall back to the original directory
try:
tokenizer = BertTokenizer.from_pretrained(model_path)
logger.info(f" ✅ 从训练目录加载tokenizer成功")
except Exception as e:
logger.warning(f" 从训练目录加载tokenizer失败: {str(e)}")
tokenizer = BertTokenizer.from_pretrained(original_model_path)
logger.info(f" ✅ 从原始目录加载tokenizer成功")
logger.info(f" 🔹 词汇表大小: {len(tokenizer.vocab)}")
# 创建模型实例,使用原始模型路径以避免词汇表不匹配
logger.info(f" 🔧 使用原始模型路径创建模型实例")
model = DualPathBoundaryClassifier(
            model_path=original_model_path,  # always use the original model path
num_labels=2,
dropout=0.1,
boundary_force_weight=boundary_force_weight
)
        # Load the trained weights
model_weights_path = os.path.join(model_path, 'pytorch_model.bin')
if os.path.exists(model_weights_path):
logger.info(f" 📥 加载训练权重...")
state_dict = torch.load(model_weights_path, map_location=device)
            # Try a strict load first; fall back to a shape-compatible load on mismatch
try:
model.load_state_dict(state_dict)
logger.info(f" ✅ 成功加载完整权重")
except RuntimeError as e:
if "size mismatch" in str(e):
logger.warning(f" 检测到权重尺寸不匹配,使用兼容性加载")
# 过滤掉不匹配的权重
model_dict = model.state_dict()
filtered_dict = {}
for k, v in state_dict.items():
if k in model_dict:
if model_dict[k].shape == v.shape:
filtered_dict[k] = v
else:
                                logger.warning(
                                    f" Skipping mismatched weight: {k} (model: {model_dict[k].shape}, checkpoint: {v.shape})")
else:
logger.warning(f" 跳过未知权重: {k}")
# 加载过滤后的权重
model_dict.update(filtered_dict)
model.load_state_dict(model_dict)
logger.info(f" ✅ 成功加载兼容权重 ({len(filtered_dict)}/{len(state_dict)} 权重已加载)")
else:
raise e
else:
logger.error(f" ❌ 找不到模型权重文件: {model_weights_path}")
return False
model.to(device)
model.eval()
total_params = sum(p.numel() for p in model.parameters())
logger.info(f"📊 模型参数: {total_params:,}")
# 测试模型是否正常工作
logger.info("🧪 测试模型推理...")
test_result = test_model_inference()
if test_result:
logger.info("✅ 模型推理测试通过")
logger.info("🚀 模型加载完成,Ready for service!")
return True
else:
logger.error("❌ 模型推理测试失败")
return False
except Exception as e:
logger.error(f"❌ 模型加载失败: {str(e)}")
import traceback
traceback.print_exc()
return False
def test_model_inference():
"""测试模型推理是否正常"""
try:
        test_sentences = [  # Chinese inputs: the model is trained on Chinese text
"这是第一个测试句子。",
"这是第二个测试句子。"
]
with torch.no_grad():
inputs = tokenizer(
test_sentences[0],
test_sentences[1],
truncation=True,
padding=True,
max_length=512,
return_tensors='pt'
)
inputs = {k: v.to(device) for k, v in inputs.items()}
outputs = model(**inputs)
            # Check the output format
required_keys = ['logits', 'boundary_score']
for key in required_keys:
if key not in outputs:
logger.error(f"模型输出缺少必要的键: {key}")
return False
logits = outputs['logits']
boundary_score = outputs['boundary_score']
            # Check output shapes
if logits.shape != torch.Size([1, 2]):
logger.error(f"logits形状不正确: {logits.shape}, 期望: [1, 2]")
return False
if boundary_score.shape != torch.Size([1]):
logger.error(f"boundary_score形状不正确: {boundary_score.shape}, 期望: [1]")
return False
            # Check the value range
if not (0 <= boundary_score.item() <= 1):
logger.error(f"boundary_score超出范围: {boundary_score.item()}")
return False
logger.info(f" 测试预测: logits={logits.tolist()}, boundary_score={boundary_score.item():.3f}")
return True
except Exception as e:
logger.error(f"模型推理测试异常: {str(e)}")
return False
def split_text_into_sentences(text: str) -> List[str]:
    """Split text into sentences on Chinese and Western sentence-ending punctuation."""
    # Splitting on a capturing group keeps the delimiters, so each sentence can
    # retain its trailing punctuation: [text, delim, text, delim, ..., text].
    # This replaces the original position-tracking re-attachment, which could
    # misattach punctuation when a sentence repeated elsewhere in the text.
    parts = re.split(r'([。!?!?])', text)
    result = []
    for i in range(0, len(parts), 2):
        sentence = parts[i].strip()
        punctuation = parts[i + 1] if i + 1 < len(parts) else ''
        if sentence:
            result.append(sentence + punctuation)
    return result
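# Example: split_text_into_sentences("今天天气好。明天呢?还不确定")
#   -> ["今天天气好。", "明天呢?", "还不确定"]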
def predict_sentence_pairs(sentences: List[str]) -> List[Dict[str, Any]]:
"""预测相邻句子对是否需要分段"""
if len(sentences) < 2:
return []
results = []
with torch.no_grad():
for i in range(len(sentences) - 1):
sentence1 = sentences[i]
sentence2 = sentences[i + 1]
# Tokenization
inputs = tokenizer(
sentence1,
sentence2,
truncation=True,
padding=True,
max_length=512,
return_tensors='pt'
)
            # Move tensors to the target device
inputs = {k: v.to(device) for k, v in inputs.items()}
            # Run the model
outputs = model(**inputs)
            # Extract predictions
logits = outputs['logits']
boundary_score = outputs['boundary_score']
probs = F.softmax(logits, dim=-1)
prediction = torch.argmax(logits, dim=-1).item()
confidence = torch.max(probs, dim=-1)[0].item()
boundary_score_value = boundary_score.item()
            # Assemble the result
result = {
'sentence1': sentence1,
'sentence2': sentence2,
                'prediction': prediction,  # 0: same paragraph, 1: new paragraph
'confidence': confidence,
'boundary_score': boundary_score_value,
'should_split': prediction == 1,
'split_reason': get_split_reason(prediction, boundary_score_value, confidence)
}
results.append(result)
return results
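# Note: pairs are scored one forward pass at a time; batching all adjacent pairs
# into a single padded batch would reduce latency on long texts.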
def get_split_reason(prediction: int, boundary_score: float, confidence: float) -> str:
"""生成分段原因说明"""
if prediction == 1:
if boundary_score > 0.7:
return f"检测到强边界信号 (边界分数: {boundary_score:.3f})"
elif confidence > 0.8:
return f"语义转换明显 (置信度: {confidence:.3f})"
else:
return f"建议分段 (置信度: {confidence:.3f})"
else:
return f"内容连贯,无需分段 (置信度: {confidence:.3f})"
def segment_text(text: str) -> Dict[str, Any]:
"""对完整文本进行分段处理"""
# 分割成句子
sentences = split_text_into_sentences(text)
if len(sentences) < 2:
return {
'original_text': text,
'sentences': sentences,
'segments': [text] if text.strip() else [],
'split_decisions': [],
'total_sentences': len(sentences),
'total_segments': 1 if text.strip() else 0
}
    # Score adjacent sentence pairs
split_decisions = predict_sentence_pairs(sentences)
    # Greedily build segments from the split decisions
segments = []
current_segment = [sentences[0]]
for i, decision in enumerate(split_decisions):
if decision['should_split']:
            # Split here: close the current segment
segments.append(''.join(current_segment))
current_segment = [sentences[i + 1]]
else:
            # Keep extending the current segment
current_segment.append(sentences[i + 1])
    # Append the final segment
if current_segment:
segments.append(''.join(current_segment))
return {
'original_text': text,
'sentences': sentences,
'segments': segments,
'split_decisions': split_decisions,
'total_sentences': len(sentences),
'total_segments': len(segments)
}
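# Worked example: for sentences [s1, s2, s3] with decisions [no-split, split],
# the greedy loop above yields segments ["s1s2", "s3"].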
# FastAPI application
app = FastAPI(title="双路径边界分类器文本分段服务", version="1.0.0")
# Request models
class TextInput(BaseModel):
text: str
class BatchTextInput(BaseModel):
texts: List[str]
# Response models
class SegmentResult(BaseModel):
original_text: str
sentences: List[str]
segments: List[str]
total_sentences: int
total_segments: int
class DetailedSegmentResult(BaseModel):
original_text: str
sentences: List[str]
segments: List[str]
split_decisions: List[Dict[str, Any]]
total_sentences: int
total_segments: int
@app.on_event("startup")
async def startup_event():
"""启动时加载模型"""
logger.info("🚀 启动双路径边界分类器服务...")
success = load_model()
if not success:
logger.error("❌ 模型加载失败,服务无法启动")
raise RuntimeError("模型加载失败")
logger.info("✅ 服务启动成功!")
@app.get("/", response_class=HTMLResponse)
async def get_frontend():
"""返回前端页面"""
html_content = """
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>双路径边界分类器 - 文本分段服务</title>
<style>
body {
font-family: 'Arial', sans-serif;
margin: 0;
padding: 20px;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
min-height: 100vh;
}
.container {
max-width: 1200px;
margin: 0 auto;
background: white;
border-radius: 15px;
box-shadow: 0 10px 30px rgba(0,0,0,0.2);
overflow: hidden;
}
.header {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
padding: 30px;
text-align: center;
}
.header h1 {
margin: 0;
font-size: 2.5em;
font-weight: bold;
}
.header p {
margin: 10px 0 0 0;
font-size: 1.2em;
opacity: 0.9;
}
.content {
padding: 30px;
}
.input-section {
margin-bottom: 30px;
}
.input-section label {
display: block;
margin-bottom: 10px;
font-weight: bold;
color: #333;
font-size: 1.1em;
}
.input-section textarea {
width: 100%;
height: 200px;
padding: 15px;
border: 2px solid #ddd;
border-radius: 10px;
font-size: 16px;
font-family: 'Arial', sans-serif;
resize: vertical;
box-sizing: border-box;
}
.input-section textarea:focus {
outline: none;
border-color: #667eea;
box-shadow: 0 0 10px rgba(102, 126, 234, 0.3);
}
.button-group {
display: flex;
gap: 15px;
margin-top: 20px;
}
.btn {
padding: 12px 25px;
border: none;
border-radius: 8px;
font-size: 16px;
font-weight: bold;
cursor: pointer;
transition: all 0.3s ease;
}
.btn-primary {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
}
.btn-primary:hover {
transform: translateY(-2px);
box-shadow: 0 5px 15px rgba(102, 126, 234, 0.4);
}
.btn-secondary {
background: #f8f9fa;
color: #333;
border: 2px solid #ddd;
}
.btn-secondary:hover {
background: #e9ecef;
}
.btn:disabled {
opacity: 0.6;
cursor: not-allowed;
transform: none !important;
}
.loading {
display: none;
text-align: center;
padding: 20px;
color: #667eea;
font-weight: bold;
}
.loading::after {
content: "";
display: inline-block;
margin-left: 10px;
width: 20px;
height: 20px;
border: 3px solid #f3f3f3;
border-top: 3px solid #667eea;
border-radius: 50%;
animation: spin 1s linear infinite;
}
@keyframes spin {
0% { transform: rotate(0deg); }
100% { transform: rotate(360deg); }
}
.result-section {
margin-top: 30px;
display: none;
}
.result-header {
background: #f8f9fa;
padding: 15px;
border-radius: 10px;
margin-bottom: 20px;
border-left: 5px solid #667eea;
}
.segment {
background: #f8f9fa;
padding: 15px;
margin: 10px 0;
border-radius: 8px;
border-left: 4px solid #28a745;
position: relative;
}
.segment-header {
font-weight: bold;
color: #28a745;
margin-bottom: 8px;
}
.segment-content {
line-height: 1.6;
color: #333;
}
.split-decision {
background: white;
padding: 10px 15px;
margin: 8px 0;
border-radius: 6px;
border-left: 3px solid #ffc107;
font-size: 14px;
}
.split-decision.split {
border-left-color: #dc3545;
background: #fff5f5;
}
.split-decision.no-split {
border-left-color: #28a745;
background: #f5fff5;
}
.stats {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
gap: 15px;
margin: 20px 0;
}
.stat-card {
background: white;
padding: 20px;
border-radius: 10px;
text-align: center;
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
}
.stat-number {
font-size: 2em;
font-weight: bold;
color: #667eea;
}
.stat-label {
color: #666;
margin-top: 5px;
}
.demo-buttons {
display: flex;
gap: 10px;
margin-top: 15px;
flex-wrap: wrap;
}
.demo-btn {
padding: 8px 15px;
background: #e9ecef;
border: 1px solid #ddd;
border-radius: 5px;
cursor: pointer;
font-size: 14px;
transition: all 0.2s ease;
}
.demo-btn:hover {
background: #dee2e6;
border-color: #667eea;
}
</style>
</head>
<body>
<div class="container">
<div class="header">
<h1>🤖 双路径边界分类器</h1>
<p>智能文本分段服务 - 基于神经网络的广播内容自动分段</p>
</div>
<div class="content">
<div class="input-section">
<label for="textInput">📝 请输入需要分段的文本内容:</label>
<textarea id="textInput" placeholder="请输入广播内容或其他需要分段的文本...&#10;&#10;示例:&#10;小月提醒大家,不认识的野菜一定不要去采,避免因误食发生过敏或者中毒。好了,健康快车赶快上车,我是小月,我们下期再会。下面即将收听到的是普法档案。欢迎收听普法档案,我是沐白。"></textarea>
<div class="demo-buttons">
<button class="demo-btn" onclick="loadDemo(1)">📻 广播节目示例</button>
<button class="demo-btn" onclick="loadDemo(2)">📰 新闻内容示例</button>
<button class="demo-btn" onclick="loadDemo(3)">📚 教育内容示例</button>
<button class="demo-btn" onclick="clearText()">🗑 清空</button>
</div>
<div class="button-group">
<button class="btn btn-primary" onclick="processText()">🚀 开始分段</button>
<button class="btn btn-secondary" onclick="clearResults()">🔄 清除结果</button>
</div>
</div>
<div class="loading" id="loading">
正在分析文本,请稍候...
</div>
<div class="result-section" id="resultSection">
<div class="result-header">
<h3>📊 分段结果</h3>
<div class="stats" id="stats"></div>
</div>
<div id="segments"></div>
<details style="margin-top: 20px;">
<summary style="cursor: pointer; font-weight: bold; color: #667eea;">🔍 查看详细分析过程</summary>
<div id="detailedAnalysis" style="margin-top: 15px;"></div>
</details>
</div>
</div>
</div>
<script>
const demoTexts = {
1: `小月提醒大家,不认识的野菜一定不要去采,避免因误食发生过敏或者中毒。食用野菜前最好留存少许的野菜或者先拍照,一旦发生不适要停止食用,立即催吐,然后携带剩余野菜呕吐物或者之前拍的照片及时就医。好了,健康快车赶快上车,我是小月,我们下期再会。下面即将收听到的是普法档案。欢迎收听普法档案,我是沐白。今天我们来讨论一个重要的法律问题。这个问题涉及到合同纠纷的处理方式。感谢大家收听今天的节目内容。接下来为您播放轻音乐时光。`,
2: `今日上午,市政府召开新闻发布会,宣布了新的城市规划方案。该方案将重点发展科技创新产业,预计投资总额达到500亿元。据了解,新规划将涵盖教育、医疗、交通等多个领域。以上就是今天的新闻内容。现在为大家播放天气预报。明天将是晴朗的一天,气温在15到25度之间。请大家注意适当增减衣物。`,
3: `在学习语言的过程中,我们需要掌握几个重要的原则。首先是要多听多说,培养语感。其次是要注意语法的正确性,避免常见错误。最后是要多阅读,扩大词汇量。今天的语言学习课程就到这里。下面请收听音乐欣赏节目。今天为大家介绍的是古典音乐的魅力。音乐能够陶冶情操,提升审美能力。`
};
function loadDemo(num) {
document.getElementById('textInput').value = demoTexts[num];
}
function clearText() {
document.getElementById('textInput').value = '';
}
function clearResults() {
document.getElementById('resultSection').style.display = 'none';
}
async function processText() {
const text = document.getElementById('textInput').value.trim();
if (!text) {
alert('请输入要分段的文本内容!');
return;
}
// 显示加载状态
document.getElementById('loading').style.display = 'block';
document.getElementById('resultSection').style.display = 'none';
try {
const response = await fetch('/segment', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({ text: text })
});
if (!response.ok) {
throw new Error('网络请求失败');
}
const result = await response.json();
displayResults(result);
} catch (error) {
alert('处理失败:' + error.message);
} finally {
document.getElementById('loading').style.display = 'none';
}
}
function displayResults(result) {
// 显示统计信息
const statsHtml = `
<div class="stat-card">
<div class="stat-number">${result.total_sentences}</div>
<div class="stat-label">总句子数</div>
</div>
<div class="stat-card">
<div class="stat-number">${result.total_segments}</div>
<div class="stat-label">分段数量</div>
</div>
<div class="stat-card">
<div class="stat-number">${(result.split_decisions || []).filter(d => d.should_split).length}</div>
<div class="stat-label">分段点数</div>
</div>
`;
document.getElementById('stats').innerHTML = statsHtml;
// 显示分段结果
const segmentsHtml = result.segments.map((segment, index) => `
<div class="segment">
<div class="segment-header">段落 ${index + 1}</div>
<div class="segment-content">${segment}</div>
</div>
`).join('');
document.getElementById('segments').innerHTML = segmentsHtml;
// 显示详细分析
if (result.split_decisions && result.split_decisions.length > 0) {
const analysisHtml = result.split_decisions.map((decision, index) => `
<div class="split-decision ${decision.should_split ? 'split' : 'no-split'}">
<strong>句子对 ${index + 1}:</strong><br>
<div style="margin: 5px 0;">
<strong>句子1:</strong> ${decision.sentence1}<br>
<strong>句子2:</strong> ${decision.sentence2}
</div>
<div style="margin: 5px 0;">
<strong>决策:</strong> ${decision.should_split ? '🔴 需要分段' : '🟢 无需分段'}<br>
<strong>置信度:</strong> ${(decision.confidence * 100).toFixed(1)}%<br>
<strong>边界分数:</strong> ${(decision.boundary_score * 100).toFixed(1)}%<br>
<strong>原因:</strong> ${decision.split_reason}
</div>
</div>
`).join('');
document.getElementById('detailedAnalysis').innerHTML = analysisHtml;
}
document.getElementById('resultSection').style.display = 'block';
document.getElementById('resultSection').scrollIntoView({ behavior: 'smooth' });
}
</script>
</body>
</html>
"""
return HTMLResponse(content=html_content)
@app.post("/segment", response_model=DetailedSegmentResult)
async def segment_text_api(input_data: TextInput):
"""文本分段API"""
try:
if not input_data.text.strip():
raise HTTPException(status_code=400, detail="输入文本不能为空")
result = segment_text(input_data.text)
return DetailedSegmentResult(**result)
except Exception as e:
logger.error(f"文本分段处理失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"处理失败: {str(e)}")
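# Example request (sketch; assumes the service is running on localhost:8888):
#   curl -X POST http://localhost:8888/segment \
#        -H "Content-Type: application/json" \
#        -d '{"text": "第一句。第二句。"}'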
@app.post("/segment/batch")
async def segment_batch_api(input_data: BatchTextInput):
"""批量文本分段API"""
try:
if not input_data.texts:
raise HTTPException(status_code=400, detail="输入文本列表不能为空")
results = []
for i, text in enumerate(input_data.texts):
if text.strip():
result = segment_text(text)
result['text_index'] = i
results.append(result)
return {"results": results, "total_processed": len(results)}
except Exception as e:
logger.error(f"批量文本分段处理失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"处理失败: {str(e)}")
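# The batch variant accepts {"texts": ["...", "..."]} and returns one result per
# non-empty input, each tagged with its original text_index.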
@app.get("/health")
async def health_check():
"""健康检查API"""
return {
"status": "healthy",
"model_loaded": model is not None,
"device": str(device) if device else "unknown",
"timestamp": datetime.now().isoformat()
}
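# Example (assuming the default port): curl http://localhost:8888/health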
@app.get("/model/info")
async def model_info():
"""模型信息API"""
if model is None:
raise HTTPException(status_code=503, detail="模型未加载")
try:
total_params = sum(p.numel() for p in model.parameters())
return {
"model_type": "DualPathBoundaryClassifier",
"total_parameters": total_params,
"device": str(device),
"boundary_force_weight": float(model.boundary_force_weight.data),
"vocab_size": len(tokenizer.vocab),
"max_length": 512
}
except Exception as e:
raise HTTPException(status_code=500, detail=f"获取模型信息失败: {str(e)}")
@app.post("/predict/pair")
async def predict_sentence_pair(sentence1: str, sentence2: str):
"""预测单个句子对是否需要分段"""
try:
if not sentence1.strip() or not sentence2.strip():
raise HTTPException(status_code=400, detail="句子不能为空")
        # Reuse the pair-prediction helper
decisions = predict_sentence_pairs([sentence1, sentence2])
if decisions:
decision = decisions[0]
return {
"sentence1": sentence1,
"sentence2": sentence2,
"should_split": decision['should_split'],
"confidence": decision['confidence'],
"boundary_score": decision['boundary_score'],
"split_reason": decision['split_reason']
}
else:
raise HTTPException(status_code=500, detail="预测失败")
except HTTPException:
raise
except Exception as e:
logger.error(f"句子对预测失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"预测失败: {str(e)}")
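# Because sentence1/sentence2 are plain str parameters rather than a Pydantic
# body, FastAPI reads them from the query string, e.g. (sketch):
#   curl -X POST "http://localhost:8888/predict/pair?sentence1=第一句。&sentence2=第二句。"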
def main():
"""启动FastAPI服务"""
logger.info("🚀 启动双路径边界分类器FastAPI服务")
logger.info(f"📝 服务地址: http://0.0.0.0:8888")
logger.info(f"🌐 前端界面: http://0.0.0.0:8888")
logger.info(f"📚 API文档: http://0.0.0.0:8888/docs")
# 启动服务
uvicorn.run(
app,
host="0.0.0.0",
port=8888,
log_level="info",
access_log=True
)
if __name__ == "__main__":
main()