import torch
import torch.nn as nn
import torch.nn.functional as F
import json
import os
import math
import logging
import re
from datetime import datetime
from typing import List, Dict, Any

from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
from transformers import BertTokenizer, BertModel
import uvicorn

# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Globals populated by load_model() at startup
model = None
tokenizer = None
device = None


class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention."""

    def __init__(self, d_model, dropout=0.1):
        super(ScaledDotProductAttention, self).__init__()
        self.d_model = d_model
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        d_model = query.size(-1)

        # Scale by sqrt(d_model) to keep the softmax logits well-conditioned
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_model)

        # Mask out padding positions with the dtype minimum before the softmax
        if mask is not None:
            mask_value = torch.finfo(scores.dtype).min
            scores = scores.masked_fill(mask == 0, mask_value)

        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        output = torch.matmul(attention_weights, value)

        return output, attention_weights
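
# Shape sketch (illustrative numbers): with batch_size=2, seq_len=8 and a
# hidden size of 768, query/key/value are each [2, 8, 768]; `output` comes
# back as [2, 8, 768] and `attention_weights` as [2, 8, 8].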


class DualPathBoundaryClassifier(nn.Module):
    """Dual-path boundary classifier; boundary patterns are learned entirely by the network."""

    def __init__(self, model_path, num_labels=2, dropout=0.1, boundary_force_weight=2.0):
        super(DualPathBoundaryClassifier, self).__init__()

        # chinese-roberta-wwm-ext uses the BERT architecture, hence BertModel
        self.roberta = BertModel.from_pretrained(model_path)
        self.config = self.roberta.config
        self.config.num_labels = num_labels

        self.scaled_attention = ScaledDotProductAttention(
            d_model=self.config.hidden_size,
            dropout=dropout
        )

        self.dropout = nn.Dropout(dropout)

        # Dual-path classifiers
        self.regular_classifier = nn.Linear(self.config.hidden_size, num_labels)
        self.boundary_classifier = nn.Linear(self.config.hidden_size, num_labels)
        self.boundary_detector = nn.Linear(self.config.hidden_size, 1)

        # Learnable boundary-forcing weight
        self.boundary_force_weight = nn.Parameter(torch.tensor(boundary_force_weight))

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        roberta_outputs = self.roberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=True
        )

        sequence_output = roberta_outputs.last_hidden_state

        # Enhance the encoder output with scaled dot-product self-attention
        enhanced_output, attention_weights = self.scaled_attention(
            query=sequence_output,
            key=sequence_output,
            value=sequence_output,
            mask=attention_mask.unsqueeze(1) if attention_mask is not None else None
        )

        cls_output = enhanced_output[:, 0, :]
        cls_output = self.dropout(cls_output)

        # Dual-path classification
        regular_logits = self.regular_classifier(cls_output)
        boundary_logits = self.boundary_classifier(cls_output)

        # Boundary detection
        boundary_logits_raw = self.boundary_detector(cls_output).squeeze(-1)
        boundary_score = torch.sigmoid(boundary_logits_raw)

        # Dynamic fusion: bias the "split" class (label 1) by the boundary score
        boundary_bias = torch.zeros_like(regular_logits)
        boundary_bias[:, 1] = boundary_score * self.boundary_force_weight

        final_logits = regular_logits + boundary_bias

        return {
            'logits': final_logits,
            'regular_logits': regular_logits,
            'boundary_logits': boundary_logits,
            'boundary_score': boundary_score,
            'hidden_states': enhanced_output,
            'attention_weights': attention_weights
        }
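
# Minimal usage sketch (hypothetical local paths; assumes a
# chinese-roberta-wwm-ext checkpoint and its BertTokenizer):
#
#   tok = BertTokenizer.from_pretrained("path/to/chinese-roberta-wwm-ext")
#   clf = DualPathBoundaryClassifier("path/to/chinese-roberta-wwm-ext").eval()
#   enc = tok("第一句。", "第二句。", return_tensors="pt")
#   with torch.no_grad():
#       out = clf(**enc)
#   p_split = out['logits'].softmax(-1)[0, 1].item()  # P(different paragraphs)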


def load_model():
    """Load the trained model."""
    global model, tokenizer, device

    # Pick a device
    if torch.cuda.is_available():
        device = torch.device('cuda')
        gpu_name = torch.cuda.get_device_name(0)
        logger.info(f"✅ Using GPU: {gpu_name}")
    else:
        device = torch.device('cpu')
        logger.info("⚠️ Running on CPU")

    # Model path configuration
    model_path = r"D:\workstation\chinese-roberta-wwm-ext\model-train-eval-NN\model_train"
    original_model_path = r"D:\workstation\chinese-roberta-wwm-ext\model"

    logger.info("📥 Loading the dual-path boundary classifier model...")

    try:
        # Read boundary_force_weight from the training config if present
        config_path = os.path.join(model_path, 'config.json')
        if os.path.exists(config_path):
            with open(config_path, 'r', encoding='utf-8') as f:
                model_config = json.load(f)
            boundary_force_weight = model_config.get('boundary_force_weight', 2.0)
            logger.info(f" 🔹 Boundary force weight: {boundary_force_weight}")
        else:
            boundary_force_weight = 2.0

        # Try the training directory first for the tokenizer, then fall back to the original one
        try:
            tokenizer = BertTokenizer.from_pretrained(model_path)
            logger.info(" ✅ Loaded tokenizer from the training directory")
        except Exception as e:
            logger.warning(f" ⚠️ Failed to load tokenizer from the training directory: {str(e)}")
            tokenizer = BertTokenizer.from_pretrained(original_model_path)
            logger.info(" ✅ Loaded tokenizer from the original directory")

        logger.info(f" 🔹 Vocabulary size: {len(tokenizer.vocab)}")

        # Build the model from the original model path to avoid vocabulary mismatches
        logger.info(" 🔧 Creating the model instance from the original model path")
        model = DualPathBoundaryClassifier(
            model_path=original_model_path,  # force the original model path
            num_labels=2,
            dropout=0.1,
            boundary_force_weight=boundary_force_weight
        )

        # Load the trained weights
        model_weights_path = os.path.join(model_path, 'pytorch_model.bin')
        if os.path.exists(model_weights_path):
            logger.info(" 📥 Loading trained weights...")
            state_dict = torch.load(model_weights_path, map_location=device)

            # Try a strict load first; fall back to a shape-compatible load on mismatch
            try:
                model.load_state_dict(state_dict)
                logger.info(" ✅ Loaded the full state dict")
            except RuntimeError as e:
                if "size mismatch" in str(e):
                    logger.warning(" ⚠️ Size mismatch detected, falling back to compatible loading")
                    # Filter out weights whose shapes do not match
                    model_dict = model.state_dict()
                    filtered_dict = {}

                    for k, v in state_dict.items():
                        if k in model_dict:
                            if model_dict[k].shape == v.shape:
                                filtered_dict[k] = v
                            else:
                                logger.warning(
                                    f" Skipping mismatched weight: {k} (model: {model_dict[k].shape}, checkpoint: {v.shape})")
                        else:
                            logger.warning(f" Skipping unknown weight: {k}")

                    # Load the filtered weights
                    model_dict.update(filtered_dict)
                    model.load_state_dict(model_dict)
                    logger.info(f" ✅ Loaded compatible weights ({len(filtered_dict)}/{len(state_dict)} tensors)")
                else:
                    raise e
        else:
            logger.error(f" ❌ Model weights file not found: {model_weights_path}")
            return False

        model.to(device)
        model.eval()

        total_params = sum(p.numel() for p in model.parameters())
        logger.info(f"📊 Model parameters: {total_params:,}")

        # Smoke-test the model
        logger.info("🧪 Testing model inference...")
        test_result = test_model_inference()
        if test_result:
            logger.info("✅ Model inference test passed")
            logger.info("🚀 Model loaded, ready for service!")
            return True
        else:
            logger.error("❌ Model inference test failed")
            return False

    except Exception as e:
        logger.error(f"❌ Failed to load the model: {str(e)}")
        import traceback
        traceback.print_exc()
        return False


def test_model_inference():
    """Check that model inference works end to end."""
    try:
        test_sentences = [
            "这是第一个测试句子。",
            "这是第二个测试句子。"
        ]

        with torch.no_grad():
            inputs = tokenizer(
                test_sentences[0],
                test_sentences[1],
                truncation=True,
                padding=True,
                max_length=512,
                return_tensors='pt'
            )

            inputs = {k: v.to(device) for k, v in inputs.items()}
            outputs = model(**inputs)

            # Check the output format
            required_keys = ['logits', 'boundary_score']
            for key in required_keys:
                if key not in outputs:
                    logger.error(f"Model output is missing a required key: {key}")
                    return False

            logits = outputs['logits']
            boundary_score = outputs['boundary_score']

            # Check the output shapes
            if logits.shape != torch.Size([1, 2]):
                logger.error(f"Unexpected logits shape: {logits.shape}, expected: [1, 2]")
                return False

            if boundary_score.shape != torch.Size([1]):
                logger.error(f"Unexpected boundary_score shape: {boundary_score.shape}, expected: [1]")
                return False

            # Check the value range
            if not (0 <= boundary_score.item() <= 1):
                logger.error(f"boundary_score out of range: {boundary_score.item()}")
                return False

            logger.info(f" Test prediction: logits={logits.tolist()}, boundary_score={boundary_score.item():.3f}")
            return True

    except Exception as e:
        logger.error(f"Model inference test raised an exception: {str(e)}")
        return False


def split_text_into_sentences(text: str) -> List[str]:
    """Split text into sentences on Chinese/ASCII terminators (。!?!?), keeping each terminator."""
    # Split with a capturing group so every terminator stays paired with the
    # sentence it ends, even when different terminators are mixed in one text.
    parts = re.split(r'([。!?!?])', text)

    result = []
    # parts alternates [body, terminator, body, terminator, ..., trailing body]
    for i in range(0, len(parts) - 1, 2):
        body = parts[i].strip()
        if body:
            result.append(body + parts[i + 1])

    # Keep any trailing text that has no terminator
    if parts[-1].strip():
        result.append(parts[-1].strip())

    return result
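
# Example:
#   split_text_into_sentences("今天天气不错。你吃了吗?好的!")
#   -> ["今天天气不错。", "你吃了吗?", "好的!"]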


def predict_sentence_pairs(sentences: List[str]) -> List[Dict[str, Any]]:
    """Predict, for each adjacent sentence pair, whether a paragraph break is needed."""
    if len(sentences) < 2:
        return []

    results = []

    with torch.no_grad():
        for i in range(len(sentences) - 1):
            sentence1 = sentences[i]
            sentence2 = sentences[i + 1]

            # Tokenize the pair as a single two-segment input
            inputs = tokenizer(
                sentence1,
                sentence2,
                truncation=True,
                padding=True,
                max_length=512,
                return_tensors='pt'
            )

            # Move tensors to the model's device
            inputs = {k: v.to(device) for k, v in inputs.items()}

            # Run the model
            outputs = model(**inputs)

            # Read off the predictions
            logits = outputs['logits']
            boundary_score = outputs['boundary_score']

            probs = F.softmax(logits, dim=-1)
            prediction = torch.argmax(logits, dim=-1).item()
            confidence = torch.max(probs, dim=-1)[0].item()
            boundary_score_value = boundary_score.item()

            result = {
                'sentence1': sentence1,
                'sentence2': sentence2,
                'prediction': prediction,  # 0: same paragraph, 1: different paragraphs
                'confidence': confidence,
                'boundary_score': boundary_score_value,
                'should_split': prediction == 1,
                'split_reason': get_split_reason(prediction, boundary_score_value, confidence)
            }

            results.append(result)

    return results


def get_split_reason(prediction: int, boundary_score: float, confidence: float) -> str:
    """Produce a human-readable reason for the split decision."""
    if prediction == 1:
        if boundary_score > 0.7:
            return f"Strong boundary signal detected (boundary score: {boundary_score:.3f})"
        elif confidence > 0.8:
            return f"Clear semantic shift (confidence: {confidence:.3f})"
        else:
            return f"Split suggested (confidence: {confidence:.3f})"
    else:
        return f"Content is coherent, no split needed (confidence: {confidence:.3f})"


def segment_text(text: str) -> Dict[str, Any]:
    """Segment a full text into paragraphs."""
    # Split into sentences
    sentences = split_text_into_sentences(text)

    if len(sentences) < 2:
        return {
            'original_text': text,
            'sentences': sentences,
            'segments': [text] if text.strip() else [],
            'split_decisions': [],
            'total_sentences': len(sentences),
            'total_segments': 1 if text.strip() else 0
        }

    # Predict over adjacent sentence pairs
    split_decisions = predict_sentence_pairs(sentences)

    # Assemble segments from the predictions
    segments = []
    current_segment = [sentences[0]]

    for i, decision in enumerate(split_decisions):
        if decision['should_split']:
            # Close the current segment and start a new one
            segments.append(''.join(current_segment))
            current_segment = [sentences[i + 1]]
        else:
            # Keep extending the current segment
            current_segment.append(sentences[i + 1])

    # Flush the last segment
    if current_segment:
        segments.append(''.join(current_segment))

    return {
        'original_text': text,
        'sentences': sentences,
        'segments': segments,
        'split_decisions': split_decisions,
        'total_sentences': len(sentences),
        'total_segments': len(segments)
    }
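
# Illustrative output shape (segment boundaries depend on the model's predictions):
#   segment_text("第一句。第二句。") -> {
#       'original_text': '第一句。第二句。',
#       'sentences': ['第一句。', '第二句。'],
#       'segments': ['第一句。第二句。'],   # single segment if no boundary is predicted
#       'split_decisions': [{...}],          # one decision dict per adjacent pair
#       'total_sentences': 2,
#       'total_segments': 1,
#   }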


# FastAPI application
app = FastAPI(title="Dual-Path Boundary Classifier Text Segmentation Service", version="1.0.0")


# Request models
class TextInput(BaseModel):
    text: str


class BatchTextInput(BaseModel):
    texts: List[str]


# Response models
class SegmentResult(BaseModel):
    original_text: str
    sentences: List[str]
    segments: List[str]
    total_sentences: int
    total_segments: int


class DetailedSegmentResult(BaseModel):
    original_text: str
    sentences: List[str]
    segments: List[str]
    split_decisions: List[Dict[str, Any]]
    total_sentences: int
    total_segments: int


@app.on_event("startup")
async def startup_event():
    """Load the model on startup."""
    logger.info("🚀 Starting the dual-path boundary classifier service...")
    success = load_model()
    if not success:
        logger.error("❌ Model loading failed, the service cannot start")
        raise RuntimeError("Model loading failed")
    logger.info("✅ Service started successfully!")


@app.get("/", response_class=HTMLResponse)
async def get_frontend():
    """Serve the frontend page."""
    html_content = """
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <title>Dual-Path Boundary Classifier - Text Segmentation Service</title>
        <style>
            body {
                font-family: 'Arial', sans-serif;
                margin: 0;
                padding: 20px;
                background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
                min-height: 100vh;
            }

            .container {
                max-width: 1200px;
                margin: 0 auto;
                background: white;
                border-radius: 15px;
                box-shadow: 0 10px 30px rgba(0,0,0,0.2);
                overflow: hidden;
            }

            .header {
                background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
                color: white;
                padding: 30px;
                text-align: center;
            }

            .header h1 {
                margin: 0;
                font-size: 2.5em;
                font-weight: bold;
            }

            .header p {
                margin: 10px 0 0 0;
                font-size: 1.2em;
                opacity: 0.9;
            }

            .content {
                padding: 30px;
            }

            .input-section {
                margin-bottom: 30px;
            }

            .input-section label {
                display: block;
                margin-bottom: 10px;
                font-weight: bold;
                color: #333;
                font-size: 1.1em;
            }

            .input-section textarea {
                width: 100%;
                height: 200px;
                padding: 15px;
                border: 2px solid #ddd;
                border-radius: 10px;
                font-size: 16px;
                font-family: 'Arial', sans-serif;
                resize: vertical;
                box-sizing: border-box;
            }

            .input-section textarea:focus {
                outline: none;
                border-color: #667eea;
                box-shadow: 0 0 10px rgba(102, 126, 234, 0.3);
            }

            .button-group {
                display: flex;
                gap: 15px;
                margin-top: 20px;
            }

            .btn {
                padding: 12px 25px;
                border: none;
                border-radius: 8px;
                font-size: 16px;
                font-weight: bold;
                cursor: pointer;
                transition: all 0.3s ease;
            }

            .btn-primary {
                background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
                color: white;
            }

            .btn-primary:hover {
                transform: translateY(-2px);
                box-shadow: 0 5px 15px rgba(102, 126, 234, 0.4);
            }

            .btn-secondary {
                background: #f8f9fa;
                color: #333;
                border: 2px solid #ddd;
            }

            .btn-secondary:hover {
                background: #e9ecef;
            }

            .btn:disabled {
                opacity: 0.6;
                cursor: not-allowed;
                transform: none !important;
            }

            .loading {
                display: none;
                text-align: center;
                padding: 20px;
                color: #667eea;
                font-weight: bold;
            }

            .loading::after {
                content: "";
                display: inline-block;
                margin-left: 10px;
                width: 20px;
                height: 20px;
                border: 3px solid #f3f3f3;
                border-top: 3px solid #667eea;
                border-radius: 50%;
                animation: spin 1s linear infinite;
            }

            @keyframes spin {
                0% { transform: rotate(0deg); }
                100% { transform: rotate(360deg); }
            }

            .result-section {
                margin-top: 30px;
                display: none;
            }

            .result-header {
                background: #f8f9fa;
                padding: 15px;
                border-radius: 10px;
                margin-bottom: 20px;
                border-left: 5px solid #667eea;
            }

            .segment {
                background: #f8f9fa;
                padding: 15px;
                margin: 10px 0;
                border-radius: 8px;
                border-left: 4px solid #28a745;
                position: relative;
            }

            .segment-header {
                font-weight: bold;
                color: #28a745;
                margin-bottom: 8px;
            }

            .segment-content {
                line-height: 1.6;
                color: #333;
            }

            .split-decision {
                background: white;
                padding: 10px 15px;
                margin: 8px 0;
                border-radius: 6px;
                border-left: 3px solid #ffc107;
                font-size: 14px;
            }

            .split-decision.split {
                border-left-color: #dc3545;
                background: #fff5f5;
            }

            .split-decision.no-split {
                border-left-color: #28a745;
                background: #f5fff5;
            }

            .stats {
                display: grid;
                grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
                gap: 15px;
                margin: 20px 0;
            }

            .stat-card {
                background: white;
                padding: 20px;
                border-radius: 10px;
                text-align: center;
                box-shadow: 0 2px 10px rgba(0,0,0,0.1);
            }

            .stat-number {
                font-size: 2em;
                font-weight: bold;
                color: #667eea;
            }

            .stat-label {
                color: #666;
                margin-top: 5px;
            }

            .demo-buttons {
                display: flex;
                gap: 10px;
                margin-top: 15px;
                flex-wrap: wrap;
            }

            .demo-btn {
                padding: 8px 15px;
                background: #e9ecef;
                border: 1px solid #ddd;
                border-radius: 5px;
                cursor: pointer;
                font-size: 14px;
                transition: all 0.2s ease;
            }

            .demo-btn:hover {
                background: #dee2e6;
                border-color: #667eea;
            }
        </style>
    </head>
    <body>
        <div class="container">
            <div class="header">
                <h1>🤖 Dual-Path Boundary Classifier</h1>
                <p>Intelligent text segmentation service - neural automatic segmentation of broadcast content</p>
            </div>

            <div class="content">
                <div class="input-section">
                    <label for="textInput">📝 Enter the text to segment:</label>
                    <textarea id="textInput" placeholder="Enter broadcast content or any other text to segment... Example: 小月提醒大家,不认识的野菜一定不要去采,避免因误食发生过敏或者中毒。好了,健康快车赶快上车,我是小月,我们下期再会。下面即将收听到的是普法档案。欢迎收听普法档案,我是沐白。"></textarea>

                    <div class="demo-buttons">
                        <button class="demo-btn" onclick="loadDemo(1)">📻 Radio program sample</button>
                        <button class="demo-btn" onclick="loadDemo(2)">📰 News sample</button>
                        <button class="demo-btn" onclick="loadDemo(3)">📚 Education sample</button>
                        <button class="demo-btn" onclick="clearText()">🗑️ Clear</button>
                    </div>

                    <div class="button-group">
                        <button class="btn btn-primary" onclick="processText()">🚀 Segment</button>
                        <button class="btn btn-secondary" onclick="clearResults()">🔄 Clear results</button>
                    </div>
                </div>

                <div class="loading" id="loading">
                    Analyzing the text, please wait...
                </div>

                <div class="result-section" id="resultSection">
                    <div class="result-header">
                        <h3>📊 Segmentation Results</h3>
                        <div class="stats" id="stats"></div>
                    </div>

                    <div id="segments"></div>

                    <details style="margin-top: 20px;">
                        <summary style="cursor: pointer; font-weight: bold; color: #667eea;">🔍 Show detailed analysis</summary>
                        <div id="detailedAnalysis" style="margin-top: 15px;"></div>
                    </details>
                </div>
            </div>
        </div>
        <script>
            // The demo texts stay in Chinese: the model is trained on Chinese text.
            const demoTexts = {
                1: `小月提醒大家,不认识的野菜一定不要去采,避免因误食发生过敏或者中毒。食用野菜前最好留存少许的野菜或者先拍照,一旦发生不适要停止食用,立即催吐,然后携带剩余野菜呕吐物或者之前拍的照片及时就医。好了,健康快车赶快上车,我是小月,我们下期再会。下面即将收听到的是普法档案。欢迎收听普法档案,我是沐白。今天我们来讨论一个重要的法律问题。这个问题涉及到合同纠纷的处理方式。感谢大家收听今天的节目内容。接下来为您播放轻音乐时光。`,

                2: `今日上午,市政府召开新闻发布会,宣布了新的城市规划方案。该方案将重点发展科技创新产业,预计投资总额达到500亿元。据了解,新规划将涵盖教育、医疗、交通等多个领域。以上就是今天的新闻内容。现在为大家播放天气预报。明天将是晴朗的一天,气温在15到25度之间。请大家注意适当增减衣物。`,

                3: `在学习语言的过程中,我们需要掌握几个重要的原则。首先是要多听多说,培养语感。其次是要注意语法的正确性,避免常见错误。最后是要多阅读,扩大词汇量。今天的语言学习课程就到这里。下面请收听音乐欣赏节目。今天为大家介绍的是古典音乐的魅力。音乐能够陶冶情操,提升审美能力。`
            };

            function loadDemo(num) {
                document.getElementById('textInput').value = demoTexts[num];
            }

            function clearText() {
                document.getElementById('textInput').value = '';
            }

            function clearResults() {
                document.getElementById('resultSection').style.display = 'none';
            }

            async function processText() {
                const text = document.getElementById('textInput').value.trim();

                if (!text) {
                    alert('Please enter some text to segment!');
                    return;
                }

                // Show the loading state
                document.getElementById('loading').style.display = 'block';
                document.getElementById('resultSection').style.display = 'none';

                try {
                    const response = await fetch('/segment', {
                        method: 'POST',
                        headers: {
                            'Content-Type': 'application/json',
                        },
                        body: JSON.stringify({ text: text })
                    });

                    if (!response.ok) {
                        throw new Error('Network request failed');
                    }

                    const result = await response.json();
                    displayResults(result);

                } catch (error) {
                    alert('Processing failed: ' + error.message);
                } finally {
                    document.getElementById('loading').style.display = 'none';
                }
            }

            function displayResults(result) {
                // Render the statistics
                const statsHtml = `
                    <div class="stat-card">
                        <div class="stat-number">${result.total_sentences}</div>
                        <div class="stat-label">Sentences</div>
                    </div>
                    <div class="stat-card">
                        <div class="stat-number">${result.total_segments}</div>
                        <div class="stat-label">Segments</div>
                    </div>
                    <div class="stat-card">
                        <div class="stat-number">${(result.split_decisions || []).filter(d => d.should_split).length}</div>
                        <div class="stat-label">Split points</div>
                    </div>
                `;
                document.getElementById('stats').innerHTML = statsHtml;

                // Render the segments
                const segmentsHtml = result.segments.map((segment, index) => `
                    <div class="segment">
                        <div class="segment-header">Paragraph ${index + 1}</div>
                        <div class="segment-content">${segment}</div>
                    </div>
                `).join('');
                document.getElementById('segments').innerHTML = segmentsHtml;

                // Render the detailed analysis
                if (result.split_decisions && result.split_decisions.length > 0) {
                    const analysisHtml = result.split_decisions.map((decision, index) => `
                        <div class="split-decision ${decision.should_split ? 'split' : 'no-split'}">
                            <strong>Sentence pair ${index + 1}:</strong><br>
                            <div style="margin: 5px 0;">
                                <strong>Sentence 1:</strong> ${decision.sentence1}<br>
                                <strong>Sentence 2:</strong> ${decision.sentence2}
                            </div>
                            <div style="margin: 5px 0;">
                                <strong>Decision:</strong> ${decision.should_split ? '🔴 Split' : '🟢 No split'}<br>
                                <strong>Confidence:</strong> ${(decision.confidence * 100).toFixed(1)}%<br>
                                <strong>Boundary score:</strong> ${(decision.boundary_score * 100).toFixed(1)}%<br>
                                <strong>Reason:</strong> ${decision.split_reason}
                            </div>
                        </div>
                    `).join('');
                    document.getElementById('detailedAnalysis').innerHTML = analysisHtml;
                }

                document.getElementById('resultSection').style.display = 'block';
                document.getElementById('resultSection').scrollIntoView({ behavior: 'smooth' });
            }
        </script>
    </body>
    </html>
    """
    return HTMLResponse(content=html_content)


@app.post("/segment", response_model=DetailedSegmentResult)
async def segment_text_api(input_data: TextInput):
    """Text segmentation API."""
    try:
        if not input_data.text.strip():
            raise HTTPException(status_code=400, detail="Input text must not be empty")

        result = segment_text(input_data.text)
        return DetailedSegmentResult(**result)

    except HTTPException:
        # Re-raise client errors instead of converting them to 500s
        raise
    except Exception as e:
        logger.error(f"Text segmentation failed: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Processing failed: {str(e)}")


@app.post("/segment/batch")
async def segment_batch_api(input_data: BatchTextInput):
    """Batch text segmentation API."""
    try:
        if not input_data.texts:
            raise HTTPException(status_code=400, detail="The input text list must not be empty")

        results = []
        for i, text in enumerate(input_data.texts):
            if text.strip():
                result = segment_text(text)
                result['text_index'] = i
                results.append(result)

        return {"results": results, "total_processed": len(results)}

    except HTTPException:
        # Re-raise client errors instead of converting them to 500s
        raise
    except Exception as e:
        logger.error(f"Batch text segmentation failed: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Processing failed: {str(e)}")


@app.get("/health")
async def health_check():
    """Health check API."""
    return {
        "status": "healthy",
        "model_loaded": model is not None,
        "device": str(device) if device else "unknown",
        "timestamp": datetime.now().isoformat()
    }


@app.get("/model/info")
async def model_info():
    """Model info API."""
    if model is None:
        raise HTTPException(status_code=503, detail="Model is not loaded")

    try:
        total_params = sum(p.numel() for p in model.parameters())
        return {
            "model_type": "DualPathBoundaryClassifier",
            "total_parameters": total_params,
            "device": str(device),
            "boundary_force_weight": float(model.boundary_force_weight.data),
            "vocab_size": len(tokenizer.vocab),
            "max_length": 512
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to fetch model info: {str(e)}")


@app.post("/predict/pair")
async def predict_sentence_pair(sentence1: str, sentence2: str):
    """Predict whether a single sentence pair needs a paragraph break."""
    try:
        if not sentence1.strip() or not sentence2.strip():
            raise HTTPException(status_code=400, detail="Sentences must not be empty")

        # Reuse the batch prediction helper on a two-sentence list
        decisions = predict_sentence_pairs([sentence1, sentence2])

        if decisions:
            decision = decisions[0]
            return {
                "sentence1": sentence1,
                "sentence2": sentence2,
                "should_split": decision['should_split'],
                "confidence": decision['confidence'],
                "boundary_score": decision['boundary_score'],
                "split_reason": decision['split_reason']
            }
        else:
            raise HTTPException(status_code=500, detail="Prediction failed")

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Sentence pair prediction failed: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}")


def main():
    """Start the FastAPI service."""
    logger.info("🚀 Starting the dual-path boundary classifier FastAPI service")
    logger.info("📝 Service address: http://0.0.0.0:8888")
    logger.info("🌐 Frontend: http://0.0.0.0:8888")
    logger.info("📚 API docs: http://0.0.0.0:8888/docs")

    # Launch the server
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8888,
        log_level="info",
        access_log=True
    )


if __name__ == "__main__":
    main()
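
# Example requests once the service is running (a sketch; adjust host/port if
# the uvicorn settings above are changed):
#
#   curl -X POST http://localhost:8888/segment \
#        -H "Content-Type: application/json" \
#        -d '{"text": "第一句。第二句。"}'
#
#   curl http://localhost:8888/health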