"""Test-set evaluation script for the Dual Path Boundary Classifier.

Loads a fine-tuned Chinese-RoBERTa-wwm-ext sentence-pair classifier, runs it on a
held-out test set, and writes metrics, plots, and per-sample predictions to disk.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (accuracy_score, classification_report,
                             confusion_matrix, precision_recall_fscore_support)
from transformers import BertTokenizer, BertModel
from torch.utils.data import Dataset, DataLoader
import logging
import os
import json
from datetime import datetime
import math

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Configure matplotlib fonts for English labels
plt.rcParams['font.sans-serif'] = ['Arial', 'DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False


def check_gpu_availability():
    """Check GPU availability and log basic device information."""
    if torch.cuda.is_available():
        gpu_count = torch.cuda.device_count()
        gpu_name = torch.cuda.get_device_name(0)
        gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024 ** 3
        logger.info("✅ GPU Available!")
        logger.info(f"   🔹 GPU Count: {gpu_count}")
        logger.info(f"   🔹 GPU Model: {gpu_name}")
        logger.info(f"   🔹 GPU Memory: {gpu_memory:.1f} GB")
        torch.cuda.empty_cache()
        torch.backends.cudnn.benchmark = True
        return True, gpu_memory
    else:
        logger.info("⚠️ GPU not available, using CPU")
        return False, 0


class FocalLoss(nn.Module):
    """Focal Loss for handling class imbalance."""

    def __init__(self, alpha=None, gamma=3.0, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        # pt is the model's probability for the true class
        pt = torch.exp(-ce_loss)
        if self.alpha is not None:
            if self.alpha.type() != inputs.data.type():
                self.alpha = self.alpha.type_as(inputs.data)
            at = self.alpha.gather(0, targets.data.view(-1))
            ce_loss = ce_loss * at
        # Down-weight easy examples: (1 - pt)^gamma -> 0 as pt -> 1
        focal_weight = (1 - pt) ** self.gamma
        focal_loss = focal_weight * ce_loss
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss


class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention over the full hidden size (single head)."""

    def __init__(self, d_model, dropout=0.1):
        super(ScaledDotProductAttention, self).__init__()
        self.d_model = d_model
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        d_model = query.size(-1)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_model)
        if mask is not None:
            # Mask out padding positions with the most negative representable value
            mask_value = torch.finfo(scores.dtype).min
            scores = scores.masked_fill(mask == 0, mask_value)
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        output = torch.matmul(attention_weights, value)
        return output, attention_weights


class DualPathBoundaryClassifier(nn.Module):
    """Dual Path Boundary Classifier with pure neural network learning."""

    def __init__(self, model_path, num_labels=2, dropout=0.1, focal_alpha=None,
                 focal_gamma=3.0, boundary_force_weight=2.0):
        super(DualPathBoundaryClassifier, self).__init__()
        self.roberta = BertModel.from_pretrained(model_path)
        self.config = self.roberta.config
        self.config.num_labels = num_labels
        self.scaled_attention = ScaledDotProductAttention(
            d_model=self.config.hidden_size,
            dropout=dropout
        )
        self.dropout = nn.Dropout(dropout)

        # Dual path classifiers
        self.regular_classifier = nn.Linear(self.config.hidden_size, num_labels)
        self.boundary_classifier = nn.Linear(self.config.hidden_size, num_labels)
        self.boundary_detector = nn.Linear(self.config.hidden_size, 1)

        # Learnable weight controlling how strongly the boundary score biases class 1
        self.boundary_force_weight = nn.Parameter(torch.tensor(boundary_force_weight))

        self.focal_loss = FocalLoss(alpha=focal_alpha, gamma=focal_gamma)
        self._init_weights()
        self.focal_alpha = focal_alpha
        self.focal_gamma = focal_gamma

    def _init_weights(self):
        """Initialize weights for the newly added layers."""
        nn.init.normal_(self.regular_classifier.weight, std=0.02)
        nn.init.zeros_(self.regular_classifier.bias)
        nn.init.normal_(self.boundary_classifier.weight, std=0.02)
        nn.init.zeros_(self.boundary_classifier.bias)
        nn.init.normal_(self.boundary_detector.weight, std=0.02)
        nn.init.zeros_(self.boundary_detector.bias)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, labels=None):
        roberta_outputs = self.roberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=True
        )
        sequence_output = roberta_outputs.last_hidden_state

        # Enhanced with scaled attention
        enhanced_output, attention_weights = self.scaled_attention(
            query=sequence_output,
            key=sequence_output,
            value=sequence_output,
            mask=attention_mask.unsqueeze(1) if attention_mask is not None else None
        )

        cls_output = enhanced_output[:, 0, :]
        cls_output = self.dropout(cls_output)

        # Dual path classification
        regular_logits = self.regular_classifier(cls_output)
        boundary_logits = self.boundary_classifier(cls_output)

        # Boundary detection
        boundary_logits_raw = self.boundary_detector(cls_output).squeeze(-1)
        boundary_score = torch.sigmoid(boundary_logits_raw)

        # Dynamic fusion: bias class 1 ("different paragraph") by the boundary score
        boundary_bias = torch.zeros_like(regular_logits)
        boundary_bias[:, 1] = boundary_score * self.boundary_force_weight
        final_logits = regular_logits + boundary_bias

        loss = None
        if labels is not None:
            regular_loss = self.focal_loss(regular_logits, labels)
            boundary_loss = self.focal_loss(boundary_logits, labels)
            final_loss = self.focal_loss(final_logits, labels)
            boundary_labels = self._generate_boundary_labels(labels)
            detection_loss = F.binary_cross_entropy_with_logits(boundary_logits_raw, boundary_labels)
            total_loss = (0.4 * final_loss + 0.3 * regular_loss +
                          0.2 * boundary_loss + 0.1 * detection_loss)
            loss = total_loss

        return {
            'loss': loss,
            'logits': final_logits,
            'regular_logits': regular_logits,
            'boundary_logits': boundary_logits,
            'boundary_score': boundary_score,
            'hidden_states': enhanced_output,
            'attention_weights': attention_weights
        }

    def _generate_boundary_labels(self, labels):
        """Generate heuristic soft labels for the boundary detector."""
        boundary_labels = labels.float()
        noise = torch.rand_like(boundary_labels) * 0.1
        boundary_labels = torch.clamp(boundary_labels + noise, 0.0, 1.0)
        return boundary_labels


class SentencePairTestDataset(Dataset):
    """Sentence-pair test dataset."""

    def __init__(self, data, tokenizer, max_length=384):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Filter valid data (keep only items with a binary label)
        self.valid_data = [item for item in data if item['label'] in [0, 1]]
        logger.info(f"Original test data: {len(data)} items, Valid data: {len(self.valid_data)} items")
        self.sentence1_list = [item['sentence1'] for item in self.valid_data]
        self.sentence2_list = [item['sentence2'] for item in self.valid_data]
        self.labels = [item['label'] for item in self.valid_data]

    def __len__(self):
        return len(self.valid_data)

    def __getitem__(self, idx):
        sentence1 = str(self.sentence1_list[idx])
        sentence2 = str(self.sentence2_list[idx])
        label = self.labels[idx]
        encoding = self.tokenizer(
            sentence1,
            sentence2,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long),
            'sentence1': sentence1,
            'sentence2': sentence2
        }


def load_trained_model(model_path, tokenizer_path):
    """Load the trained dual path model and its tokenizer."""
    logger.info(f"Loading trained model from: {model_path}")

    # Load tokenizer
    tokenizer = BertTokenizer.from_pretrained(tokenizer_path)

    # Load model configuration (if present, for logging only)
    config_path = os.path.join(model_path, 'config.json')
    if os.path.exists(config_path):
        with open(config_path, 'r', encoding='utf-8') as f:
            model_config = json.load(f)
        logger.info(f"Model config loaded: {model_config.get('model_type', 'Unknown')}")

    # Initialize model
    model = DualPathBoundaryClassifier(
        model_path=tokenizer_path,  # Use original model path for the RoBERTa base
        num_labels=2,
        dropout=0.1,
        focal_alpha=None,
        focal_gamma=3.0,
        boundary_force_weight=2.0
    )

    # Load trained weights
    model_weights_path = os.path.join(model_path, 'pytorch_model.bin')
    if os.path.exists(model_weights_path):
        model.load_state_dict(torch.load(model_weights_path, map_location='cpu'))
        logger.info("✅ Trained model weights loaded successfully")
    else:
        raise FileNotFoundError(f"Model weights not found at: {model_weights_path}")

    return model, tokenizer


def evaluate_model(model, test_dataloader, device, output_dir):
    """Evaluate the model on the test dataset."""
    model.eval()
    all_predictions = []
    all_labels = []
    all_probabilities = []
    all_boundary_scores = []

    logger.info("🔍 Starting model evaluation...")
    with torch.no_grad():
        for batch_idx, batch in enumerate(test_dataloader):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )
            logits = outputs['logits']
            boundary_scores = outputs['boundary_score']

            # Get predictions
            probabilities = torch.softmax(logits, dim=-1)
            predictions = torch.argmax(logits, dim=-1)

            all_predictions.extend(predictions.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probabilities.extend(probabilities.cpu().numpy())
            all_boundary_scores.extend(boundary_scores.cpu().numpy())

            if (batch_idx + 1) % 100 == 0:
                logger.info(f"   Processed {batch_idx + 1} batches...")

    # Convert to numpy arrays
    all_predictions = np.array(all_predictions)
    all_labels = np.array(all_labels)
    all_probabilities = np.array(all_probabilities)
    all_boundary_scores = np.array(all_boundary_scores)

    # Calculate metrics
    accuracy = accuracy_score(all_labels, all_predictions)
    precision, recall, f1, support = precision_recall_fscore_support(all_labels, all_predictions, average=None)

    # Generate detailed classification report
    class_report = classification_report(all_labels, all_predictions,
                                         target_names=['Same Paragraph (0)', 'Different Paragraph (1)'],
                                         output_dict=True)

    # Generate confusion matrix
    cm = confusion_matrix(all_labels, all_predictions)

    logger.info("📊 Test Results:")
    logger.info(f"   Overall Accuracy: {accuracy:.4f}")
    logger.info(
        f"   Class 0 (Same Paragraph) - Precision: {precision[0]:.4f}, Recall: {recall[0]:.4f}, F1: {f1[0]:.4f}")
    logger.info(
        f"   Class 1 (Different Paragraph) - Precision: {precision[1]:.4f}, Recall: {recall[1]:.4f}, F1: {f1[1]:.4f}")

    # Save results
    results = {
        'overall_accuracy': float(accuracy),
        'class_metrics': {
            'class_0_same_paragraph': {
                'precision': float(precision[0]),
                'recall': float(recall[0]),
                'f1_score': float(f1[0]),
                'support': int(support[0])
            },
            'class_1_different_paragraph': {
                'precision': float(precision[1]),
                'recall': float(recall[1]),
                'f1_score': float(f1[1]),
                'support': int(support[1])
            }
        },
        'confusion_matrix': cm.tolist(),
        'classification_report': class_report,
        'test_samples_count': len(all_predictions),
        'boundary_score_stats': {
            'mean': float(np.mean(all_boundary_scores)),
            'std': float(np.std(all_boundary_scores)),
            'min': float(np.min(all_boundary_scores)),
            'max': float(np.max(all_boundary_scores))
        }
    }

    return results, cm, all_predictions, all_labels, all_probabilities, all_boundary_scores


def plot_confusion_matrix(cm, output_dir, model_name="Dual Path Boundary Classifier"):
    """Plot and save the confusion matrix."""
    plt.figure(figsize=(10, 8))

    # Create heatmap
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=['Same Paragraph (0)', 'Different Paragraph (1)'],
                yticklabels=['Same Paragraph (0)', 'Different Paragraph (1)'],
                cbar_kws={'label': 'Number of Samples'})

    plt.title(f'Confusion Matrix - {model_name}\nTest Dataset Evaluation',
              fontsize=16, fontweight='bold', pad=20)
    plt.xlabel('Predicted Label', fontsize=14, fontweight='bold')
    plt.ylabel('True Label', fontsize=14, fontweight='bold')

    # Add accuracy text
    accuracy = np.trace(cm) / np.sum(cm)
    plt.text(0.5, -0.15, f'Overall Accuracy: {accuracy:.4f}',
             ha='center', transform=plt.gca().transAxes, fontsize=12,
             bbox=dict(boxstyle="round,pad=0.3", facecolor="lightblue", alpha=0.7))

    # Add sample counts
    total_samples = np.sum(cm)
    plt.text(0.5, -0.25, f'Total Test Samples: {total_samples}',
             ha='center', transform=plt.gca().transAxes, fontsize=10)

    plt.tight_layout()

    # Save plot
    save_path = os.path.join(output_dir, 'confusion_matrix_test_results.png')
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()

    logger.info(f"📊 Confusion matrix saved: {save_path}")
    return save_path


def plot_class_distribution(results, output_dir):
    """Plot class distribution and per-class performance metrics."""
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 12))

    # Class distribution
    class_0_count = results['class_metrics']['class_0_same_paragraph']['support']
    class_1_count = results['class_metrics']['class_1_different_paragraph']['support']
    ax1.bar(['Same Paragraph (0)', 'Different Paragraph (1)'], [class_0_count, class_1_count],
            color=['skyblue', 'lightcoral'])
    ax1.set_title('Test Dataset Class Distribution', fontweight='bold')
    ax1.set_ylabel('Number of Samples')

    # Add count labels on bars
    ax1.text(0, class_0_count + max(class_0_count, class_1_count) * 0.01, str(class_0_count),
             ha='center', fontweight='bold')
    ax1.text(1, class_1_count + max(class_0_count, class_1_count) * 0.01, str(class_1_count),
             ha='center', fontweight='bold')

    # Precision comparison
    precision_0 = results['class_metrics']['class_0_same_paragraph']['precision']
    precision_1 = results['class_metrics']['class_1_different_paragraph']['precision']
    ax2.bar(['Same Paragraph (0)', 'Different Paragraph (1)'], [precision_0, precision_1],
            color=['lightgreen', 'orange'])
    ax2.set_title('Precision by Class', fontweight='bold')
    ax2.set_ylabel('Precision Score')
    ax2.set_ylim(0, 1)

    # Recall comparison
    recall_0 = results['class_metrics']['class_0_same_paragraph']['recall']
    recall_1 = results['class_metrics']['class_1_different_paragraph']['recall']
    ax3.bar(['Same Paragraph (0)', 'Different Paragraph (1)'], [recall_0, recall_1],
            color=['lightcyan', 'plum'])
    ax3.set_title('Recall by Class', fontweight='bold')
    ax3.set_ylabel('Recall Score')
    ax3.set_ylim(0, 1)

    # F1-Score comparison
    f1_0 = results['class_metrics']['class_0_same_paragraph']['f1_score']
    f1_1 = results['class_metrics']['class_1_different_paragraph']['f1_score']
    ax4.bar(['Same Paragraph (0)', 'Different Paragraph (1)'], [f1_0, f1_1],
            color=['gold', 'mediumpurple'])
    ax4.set_title('F1-Score by Class', fontweight='bold')
    ax4.set_ylabel('F1-Score')
    ax4.set_ylim(0, 1)

    plt.suptitle('Dual Path Boundary Classifier - Test Performance Analysis',
                 fontsize=16, fontweight='bold')
    plt.tight_layout()

    # Save plot
    save_path = os.path.join(output_dir, 'class_performance_analysis.png')
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()

    logger.info(f"📊 Class performance analysis saved: {save_path}")
    return save_path


def main():
    """Main entry point."""
    logger.info("=" * 80)
    logger.info("🚀 Dual Path Boundary Classifier - Test Evaluation")
    logger.info("=" * 80)

    # Configuration
    original_model_path = r"D:\workstation\chinese-roberta-wwm-ext\model"
    trained_model_path = r"D:\workstation\chinese-roberta-wwm-ext\model-train-eval-NN\model_train"
    test_file = r"D:\workstation\AI标注\数据清洗+json\test_dataset.json"
    output_dir = r"D:\workstation\AI标注\test"

    # Ensure output directory exists
    os.makedirs(output_dir, exist_ok=True)

    # Check GPU availability
    gpu_available, gpu_memory = check_gpu_availability()
    device = torch.device('cuda' if gpu_available else 'cpu')

    logger.info("📋 Test Configuration:")
    logger.info(f"   🔹 Original Model: {original_model_path}")
    logger.info(f"   🔹 Trained Model: {trained_model_path}")
    logger.info(f"   🔹 Test Dataset: {test_file}")
    logger.info(f"   🔹 Output Directory: {output_dir}")
    logger.info("   🔹 Max Length: 384 tokens")
    logger.info(f"   🔹 Device: {device}")

    try:
        # Load test data
        logger.info("📥 Loading test dataset...")
        with open(test_file, 'r', encoding='utf-8') as f:
            test_data = json.load(f)
        logger.info(f"   Loaded {len(test_data)} test samples")

        # Load trained model
        model, tokenizer = load_trained_model(trained_model_path, original_model_path)
        model = model.to(device)

        # Create test dataset
        test_dataset = SentencePairTestDataset(test_data, tokenizer, max_length=384)
        test_dataloader = DataLoader(
            test_dataset,
            batch_size=32,  # Optimized for RTX 4060
            shuffle=False,
            num_workers=4,
            pin_memory=gpu_available
        )
        logger.info(f"   Test dataset size: {len(test_dataset)}")
        logger.info("   Batch size: 32")

        # Evaluate model
        start_time = datetime.now()
        results, cm, predictions, labels, probabilities, boundary_scores = evaluate_model(
            model, test_dataloader, device, output_dir
        )
        end_time = datetime.now()
        evaluation_time = end_time - start_time
        logger.info(f"⏱️ Evaluation completed in: {evaluation_time}")

        # Generate visualizations
        logger.info("📊 Generating visualizations...")

        # Plot confusion matrix
        cm_path = plot_confusion_matrix(cm, output_dir)

        # Plot class performance analysis
        perf_path = plot_class_distribution(results, output_dir)

        # Save detailed results
        results['evaluation_info'] = {
            'evaluation_time': str(evaluation_time),
            'device_used': str(device),
            'model_type': 'DualPathBoundaryClassifier',
            'max_length': 384,
            'batch_size': 32,
            'test_file': test_file,
            'trained_model_path': trained_model_path
        }
        results_path = os.path.join(output_dir, 'test_results_detailed.json')
        with open(results_path, 'w', encoding='utf-8') as f:
            json.dump(results, f, ensure_ascii=False, indent=2)

        # Save per-sample predictions
        predictions_data = []
        for i in range(len(predictions)):
            predictions_data.append({
                'index': i,
                'sentence1': test_dataset[i]['sentence1'],
                'sentence2': test_dataset[i]['sentence2'],
                'true_label': int(labels[i]),
                'predicted_label': int(predictions[i]),
                'probability_class_0': float(probabilities[i][0]),
                'probability_class_1': float(probabilities[i][1]),
                'boundary_score': float(boundary_scores[i]),
                'correct': bool(labels[i] == predictions[i])
            })
        predictions_path = os.path.join(output_dir, 'detailed_predictions.json')
        with open(predictions_path, 'w', encoding='utf-8') as f:
            json.dump(predictions_data, f, ensure_ascii=False, indent=2)

        # Generate summary report
        summary = {
            'model_info': {
                'model_type': 'Dual Path Boundary Classifier',
                'base_model': 'Chinese-RoBERTa-WWM-Ext',
                'max_length': 384,
                'trained_model_path': trained_model_path
            },
            'test_results': {
                'overall_accuracy': results['overall_accuracy'],
                'total_samples': len(predictions),
                'correct_predictions': int(np.sum(labels == predictions)),
                'incorrect_predictions': int(np.sum(labels != predictions))
            },
            'class_performance': results['class_metrics'],
            'boundary_detection': results['boundary_score_stats'],
            'files_generated': [
                'test_results_detailed.json',
                'detailed_predictions.json',
                'confusion_matrix_test_results.png',
                'class_performance_analysis.png',
                'test_summary.json'
            ]
        }
        summary_path = os.path.join(output_dir, 'test_summary.json')
        with open(summary_path, 'w', encoding='utf-8') as f:
            json.dump(summary, f, ensure_ascii=False, indent=2)

        # Print final results
        logger.info("=" * 80)
        logger.info("🎉 Test Evaluation Completed!")
        logger.info("=" * 80)
        logger.info("📊 Final Results:")
        logger.info(f"   🔹 Overall Accuracy: {results['overall_accuracy']:.4f}")
        logger.info(f"   🔹 Total Test Samples: {len(predictions)}")
        logger.info(f"   🔹 Correct Predictions: {np.sum(labels == predictions)}")
        logger.info(f"   🔹 Evaluation Time: {evaluation_time}")

        logger.info("\n📈 Class Performance:")
        logger.info("   Class 0 (Same Paragraph):")
        logger.info(f"     • Precision: {results['class_metrics']['class_0_same_paragraph']['precision']:.4f}")
        logger.info(f"     • Recall: {results['class_metrics']['class_0_same_paragraph']['recall']:.4f}")
        logger.info(f"     • F1-Score: {results['class_metrics']['class_0_same_paragraph']['f1_score']:.4f}")
        logger.info("   Class 1 (Different Paragraph):")
        logger.info(f"     • Precision: {results['class_metrics']['class_1_different_paragraph']['precision']:.4f}")
        logger.info(f"     • Recall: {results['class_metrics']['class_1_different_paragraph']['recall']:.4f}")
        logger.info(f"     • F1-Score: {results['class_metrics']['class_1_different_paragraph']['f1_score']:.4f}")

        logger.info(f"\n📁 Generated Files in {output_dir}:")
        logger.info("   📄 test_summary.json - Test evaluation summary")
        logger.info("   📄 test_results_detailed.json - Detailed test results")
        logger.info("   📄 detailed_predictions.json - Individual predictions")
        logger.info("   📊 confusion_matrix_test_results.png - Confusion matrix")
        logger.info("   📊 class_performance_analysis.png - Performance analysis")

        logger.info("\n🎯 Model Performance Summary:")
        logger.info("   ✅ Dual Path Boundary Classifier successfully evaluated")
        logger.info("   ✅ Optimized for RTX 4060 with max_length=384")
        logger.info("   ✅ English visualizations generated")
        logger.info(f"   ✅ All results saved to: {output_dir}")

    except Exception as e:
        logger.error(f"❌ Error during evaluation: {str(e)}")
        import traceback
        traceback.print_exc()
        raise


if __name__ == "__main__":
    main()