import torch import torch.nn as nn import torch.nn.functional as F import numpy as np import json import os import math import logging import re from datetime import datetime from typing import List, Dict, Any, Optional from collections import Counter from fastapi import FastAPI, HTTPException from fastapi.responses import HTMLResponse, JSONResponse from fastapi.staticfiles import StaticFiles from pydantic import BaseModel from transformers import BertTokenizer, BertModel import uvicorn # 设置日志 logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # 全局变量 model = None tokenizer = None device = None class ScaledDotProductAttention(nn.Module): """缩放点积注意力机制""" def __init__(self, d_model, dropout=0.1): super(ScaledDotProductAttention, self).__init__() self.d_model = d_model self.dropout = nn.Dropout(dropout) def forward(self, query, key, value, mask=None): batch_size, seq_len, d_model = query.size() scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_model) if mask is not None: mask_value = torch.finfo(scores.dtype).min scores = scores.masked_fill(mask == 0, mask_value) attention_weights = F.softmax(scores, dim=-1) attention_weights = self.dropout(attention_weights) output = torch.matmul(attention_weights, value) return output, attention_weights class DualPathBoundaryClassifier(nn.Module): """双路径边界分类器,完全依靠神经网络学习边界模式""" def __init__(self, model_path, num_labels=2, dropout=0.1, boundary_force_weight=2.0): super(DualPathBoundaryClassifier, self).__init__() self.roberta = BertModel.from_pretrained(model_path) self.config = self.roberta.config self.config.num_labels = num_labels self.scaled_attention = ScaledDotProductAttention( d_model=self.config.hidden_size, dropout=dropout ) self.dropout = nn.Dropout(dropout) # 双路径分类器 self.regular_classifier = nn.Linear(self.config.hidden_size, num_labels) self.boundary_classifier = nn.Linear(self.config.hidden_size, num_labels) self.boundary_detector = nn.Linear(self.config.hidden_size, 1) # 边界强制权重 self.boundary_force_weight = nn.Parameter(torch.tensor(boundary_force_weight)) def forward(self, input_ids, attention_mask=None, token_type_ids=None): roberta_outputs = self.roberta( input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, return_dict=True ) sequence_output = roberta_outputs.last_hidden_state # 缩放点积注意力增强 enhanced_output, attention_weights = self.scaled_attention( query=sequence_output, key=sequence_output, value=sequence_output, mask=attention_mask.unsqueeze(1) if attention_mask is not None else None ) cls_output = enhanced_output[:, 0, :] cls_output = self.dropout(cls_output) # 双路径分类 regular_logits = self.regular_classifier(cls_output) boundary_logits = self.boundary_classifier(cls_output) # 边界检测 boundary_logits_raw = self.boundary_detector(cls_output).squeeze(-1) boundary_score = torch.sigmoid(boundary_logits_raw) # 动态融合 boundary_bias = torch.zeros_like(regular_logits) boundary_bias[:, 1] = boundary_score * self.boundary_force_weight final_logits = regular_logits + boundary_bias return { 'logits': final_logits, 'regular_logits': regular_logits, 'boundary_logits': boundary_logits, 'boundary_score': boundary_score, 'hidden_states': enhanced_output, 'attention_weights': attention_weights } def load_model(): """加载训练好的模型""" global model, tokenizer, device # 检查GPU if torch.cuda.is_available(): device = torch.device('cuda') gpu_name = torch.cuda.get_device_name(0) logger.info(f"✅ 使用GPU: {gpu_name}") else: device = torch.device('cpu') logger.info("⚠️ 使用CPU运行") # 模型路径配置 model_path = r"D:\workstation\chinese-roberta-wwm-ext\model-train-eval-NN\model_train" original_model_path = r"D:\workstation\chinese-roberta-wwm-ext\model" logger.info(f"📥 加载双路径边界分类器模型...") try: # 先检查训练模型目录是否存在配置文件 config_path = os.path.join(model_path, 'config.json') if os.path.exists(config_path): with open(config_path, 'r', encoding='utf-8') as f: model_config = json.load(f) boundary_force_weight = model_config.get('boundary_force_weight', 2.0) logger.info(f" 🔹 边界强制权重: {boundary_force_weight}") else: boundary_force_weight = 2.0 # 先尝试从训练目录加载tokenizer,如果失败则使用原始目录 try: tokenizer = BertTokenizer.from_pretrained(model_path) logger.info(f" ✅ 从训练目录加载tokenizer成功") except Exception as e: logger.warning(f" ⚠️ 从训练目录加载tokenizer失败: {str(e)}") tokenizer = BertTokenizer.from_pretrained(original_model_path) logger.info(f" ✅ 从原始目录加载tokenizer成功") logger.info(f" 🔹 词汇表大小: {len(tokenizer.vocab)}") # 创建模型实例,使用原始模型路径以避免词汇表不匹配 logger.info(f" 🔧 使用原始模型路径创建模型实例") model = DualPathBoundaryClassifier( model_path=original_model_path, # 强制使用原始模型路径 num_labels=2, dropout=0.1, boundary_force_weight=boundary_force_weight ) # 加载训练好的权重 model_weights_path = os.path.join(model_path, 'pytorch_model.bin') if os.path.exists(model_weights_path): logger.info(f" 📥 加载训练权重...") state_dict = torch.load(model_weights_path, map_location=device) # 尝试加载权重,如果失败则使用更安全的方法 try: model.load_state_dict(state_dict) logger.info(f" ✅ 成功加载完整权重") except RuntimeError as e: if "size mismatch" in str(e): logger.warning(f" ⚠️ 检测到权重尺寸不匹配,使用兼容性加载") # 过滤掉不匹配的权重 model_dict = model.state_dict() filtered_dict = {} for k, v in state_dict.items(): if k in model_dict: if model_dict[k].shape == v.shape: filtered_dict[k] = v else: logger.warning( f" 跳过不匹配的权重: {k} (模型: {model_dict[k].shape}, 检查点: {v.shape})") else: logger.warning(f" 跳过未知权重: {k}") # 加载过滤后的权重 model_dict.update(filtered_dict) model.load_state_dict(model_dict) logger.info(f" ✅ 成功加载兼容权重 ({len(filtered_dict)}/{len(state_dict)} 权重已加载)") else: raise e else: logger.error(f" ❌ 找不到模型权重文件: {model_weights_path}") return False model.to(device) model.eval() total_params = sum(p.numel() for p in model.parameters()) logger.info(f"📊 模型参数: {total_params:,}") # 测试模型是否正常工作 logger.info("🧪 测试模型推理...") test_result = test_model_inference() if test_result: logger.info("✅ 模型推理测试通过") logger.info("🚀 模型加载完成,Ready for service!") return True else: logger.error("❌ 模型推理测试失败") return False except Exception as e: logger.error(f"❌ 模型加载失败: {str(e)}") import traceback traceback.print_exc() return False def test_model_inference(): """测试模型推理是否正常""" try: test_sentences = [ "这是第一个测试句子。", "这是第二个测试句子。" ] with torch.no_grad(): inputs = tokenizer( test_sentences[0], test_sentences[1], truncation=True, padding=True, max_length=512, return_tensors='pt' ) inputs = {k: v.to(device) for k, v in inputs.items()} outputs = model(**inputs) # 检查输出格式 required_keys = ['logits', 'boundary_score'] for key in required_keys: if key not in outputs: logger.error(f"模型输出缺少必要的键: {key}") return False logits = outputs['logits'] boundary_score = outputs['boundary_score'] # 检查输出形状 if logits.shape != torch.Size([1, 2]): logger.error(f"logits形状不正确: {logits.shape}, 期望: [1, 2]") return False if boundary_score.shape != torch.Size([1]): logger.error(f"boundary_score形状不正确: {boundary_score.shape}, 期望: [1]") return False # 检查数值范围 if not (0 <= boundary_score.item() <= 1): logger.error(f"boundary_score超出范围: {boundary_score.item()}") return False logger.info(f" 测试预测: logits={logits.tolist()}, boundary_score={boundary_score.item():.3f}") return True except Exception as e: logger.error(f"模型推理测试异常: {str(e)}") return False def split_text_into_sentences(text: str) -> List[str]: """将文本按句号、感叹号、问号分割成句子""" # 中文句子分割规则 sentence_endings = r'[。!?!?]' sentences = re.split(sentence_endings, text) # 过滤空句子,保留标点符号 result = [] for i, sentence in enumerate(sentences): sentence = sentence.strip() if sentence: # 如果不是最后一个句子,添加标点符号 if i < len(sentences) - 1: # 找到原始标点符号 original_text = text start_pos = 0 for j in range(i): if sentences[j].strip(): start_pos = original_text.find(sentences[j].strip(), start_pos) + len(sentences[j].strip()) # 查找句子后的标点符号 remaining_text = original_text[start_pos:] punctuation_match = re.search(sentence_endings, remaining_text) if punctuation_match: sentence += punctuation_match.group() result.append(sentence) return result def predict_sentence_pairs(sentences: List[str]) -> List[Dict[str, Any]]: """预测相邻句子对是否需要分段""" if len(sentences) < 2: return [] results = [] with torch.no_grad(): for i in range(len(sentences) - 1): sentence1 = sentences[i] sentence2 = sentences[i + 1] # Tokenization inputs = tokenizer( sentence1, sentence2, truncation=True, padding=True, max_length=512, return_tensors='pt' ) # 移动到设备 inputs = {k: v.to(device) for k, v in inputs.items()} # 模型预测 outputs = model(**inputs) # 获取预测结果 logits = outputs['logits'] boundary_score = outputs['boundary_score'] probs = F.softmax(logits, dim=-1) prediction = torch.argmax(logits, dim=-1).item() confidence = torch.max(probs, dim=-1)[0].item() boundary_score_value = boundary_score.item() # 结果 result = { 'sentence1': sentence1, 'sentence2': sentence2, 'prediction': prediction, # 0: 同段落, 1: 不同段落 'confidence': confidence, 'boundary_score': boundary_score_value, 'should_split': prediction == 1, 'split_reason': get_split_reason(prediction, boundary_score_value, confidence) } results.append(result) return results def get_split_reason(prediction: int, boundary_score: float, confidence: float) -> str: """生成分段原因说明""" if prediction == 1: if boundary_score > 0.7: return f"检测到强边界信号 (边界分数: {boundary_score:.3f})" elif confidence > 0.8: return f"语义转换明显 (置信度: {confidence:.3f})" else: return f"建议分段 (置信度: {confidence:.3f})" else: return f"内容连贯,无需分段 (置信度: {confidence:.3f})" def segment_text(text: str) -> Dict[str, Any]: """对完整文本进行分段处理""" # 分割成句子 sentences = split_text_into_sentences(text) if len(sentences) < 2: return { 'original_text': text, 'sentences': sentences, 'segments': [text] if text.strip() else [], 'split_decisions': [], 'total_sentences': len(sentences), 'total_segments': 1 if text.strip() else 0 } # 预测相邻句子对 split_decisions = predict_sentence_pairs(sentences) # 根据预测结果进行分段 segments = [] current_segment = [sentences[0]] for i, decision in enumerate(split_decisions): if decision['should_split']: # 需要分段,结束当前段落 segments.append(''.join(current_segment)) current_segment = [sentences[i + 1]] else: # 不需要分段,继续当前段落 current_segment.append(sentences[i + 1]) # 添加最后一个段落 if current_segment: segments.append(''.join(current_segment)) return { 'original_text': text, 'sentences': sentences, 'segments': segments, 'split_decisions': split_decisions, 'total_sentences': len(sentences), 'total_segments': len(segments) } # FastAPI 应用 app = FastAPI(title="双路径边界分类器文本分段服务", version="1.0.0") # 请求模型 class TextInput(BaseModel): text: str class BatchTextInput(BaseModel): texts: List[str] # 响应模型 class SegmentResult(BaseModel): original_text: str sentences: List[str] segments: List[str] total_sentences: int total_segments: int class DetailedSegmentResult(BaseModel): original_text: str sentences: List[str] segments: List[str] split_decisions: List[Dict[str, Any]] total_sentences: int total_segments: int @app.on_event("startup") async def startup_event(): """启动时加载模型""" logger.info("🚀 启动双路径边界分类器服务...") success = load_model() if not success: logger.error("❌ 模型加载失败,服务无法启动") raise RuntimeError("模型加载失败") logger.info("✅ 服务启动成功!") @app.get("/", response_class=HTMLResponse) async def get_frontend(): """返回前端页面""" html_content = """
智能文本分段服务 - 基于神经网络的广播内容自动分段