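"""Test-set evaluation script for the Dual Path Boundary Classifier.

Loads a fine-tuned sentence-pair model built on Chinese-RoBERTa-WWM-Ext, runs it
over a JSON test set of (sentence1, sentence2, label) items, and writes metrics,
per-sample predictions, and English-language plots to the output directory.
"""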
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, precision_recall_fscore_support
from transformers import BertTokenizer, BertModel
from torch.utils.data import Dataset, DataLoader
import logging
import os
import json
from datetime import datetime
import math
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Use English-capable matplotlib fonts
plt.rcParams['font.sans-serif'] = ['Arial', 'DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False
def check_gpu_availability():
    """Check whether a CUDA GPU is available and log basic device info."""
    if torch.cuda.is_available():
        gpu_count = torch.cuda.device_count()
        gpu_name = torch.cuda.get_device_name(0)
        gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024 ** 3
        logger.info(f"✅ GPU Available!")
        logger.info(f" 🔹 GPU Count: {gpu_count}")
        logger.info(f" 🔹 GPU Model: {gpu_name}")
        logger.info(f" 🔹 GPU Memory: {gpu_memory:.1f} GB")
        torch.cuda.empty_cache()
        torch.backends.cudnn.benchmark = True
        return True, gpu_memory
    else:
        logger.info("GPU not available, using CPU")
        return False, 0

class FocalLoss(nn.Module):
    """Focal Loss for handling class imbalance"""

    def __init__(self, alpha=None, gamma=3.0, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        if self.alpha is not None:
            if self.alpha.type() != inputs.data.type():
                self.alpha = self.alpha.type_as(inputs.data)
            at = self.alpha.gather(0, targets.data.view(-1))
            ce_loss = ce_loss * at
        focal_weight = (1 - pt) ** self.gamma
        focal_loss = focal_weight * ce_loss
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

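# Illustrative note (not executed): FocalLoss above implements
#   FL(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t)
# since ce_loss = -log(p_t) and pt = exp(-ce_loss). With the default gamma=3.0
# the weight (1 - p_t)**gamma shrinks quickly for confident predictions, e.g.
# p_t = 0.9 -> 0.001 and p_t = 0.5 -> 0.125, so hard examples dominate the
# averaged loss.
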
class ScaledDotProductAttention(nn.Module):
    """Scaled Dot Product Attention"""

    def __init__(self, d_model, dropout=0.1):
        super(ScaledDotProductAttention, self).__init__()
        self.d_model = d_model
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        batch_size, seq_len, d_model = query.size()
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_model)
        if mask is not None:
            mask_value = torch.finfo(scores.dtype).min
            scores = scores.masked_fill(mask == 0, mask_value)
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        output = torch.matmul(attention_weights, value)
        return output, attention_weights

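# Illustrative note (not executed): this single-head module computes
#   Attention(Q, K, V) = softmax(Q K^T / sqrt(d)) V
# where d is the full hidden size (no head splitting), and masked positions are
# filled with the dtype's minimum value before the softmax.
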
class DualPathBoundaryClassifier(nn.Module):
    """Dual Path Boundary Classifier with Pure Neural Network Learning"""

    def __init__(self, model_path, num_labels=2, dropout=0.1,
                 focal_alpha=None, focal_gamma=3.0, boundary_force_weight=2.0):
        super(DualPathBoundaryClassifier, self).__init__()
        self.roberta = BertModel.from_pretrained(model_path)
        self.config = self.roberta.config
        self.config.num_labels = num_labels
        self.scaled_attention = ScaledDotProductAttention(
            d_model=self.config.hidden_size,
            dropout=dropout
        )
        self.dropout = nn.Dropout(dropout)
        # Dual path classifiers
        self.regular_classifier = nn.Linear(self.config.hidden_size, num_labels)
        self.boundary_classifier = nn.Linear(self.config.hidden_size, num_labels)
        self.boundary_detector = nn.Linear(self.config.hidden_size, 1)
        # Boundary force weight
        self.boundary_force_weight = nn.Parameter(torch.tensor(boundary_force_weight))
        self.focal_loss = FocalLoss(alpha=focal_alpha, gamma=focal_gamma)
        self._init_weights()
        self.focal_alpha = focal_alpha
        self.focal_gamma = focal_gamma

    def _init_weights(self):
        """Initialize weights for new layers"""
        nn.init.normal_(self.regular_classifier.weight, std=0.02)
        nn.init.zeros_(self.regular_classifier.bias)
        nn.init.normal_(self.boundary_classifier.weight, std=0.02)
        nn.init.zeros_(self.boundary_classifier.bias)
        nn.init.normal_(self.boundary_detector.weight, std=0.02)
        nn.init.zeros_(self.boundary_detector.bias)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, labels=None):
        roberta_outputs = self.roberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=True
        )
        sequence_output = roberta_outputs.last_hidden_state
        # Enhanced with scaled attention
        enhanced_output, attention_weights = self.scaled_attention(
            query=sequence_output,
            key=sequence_output,
            value=sequence_output,
            mask=attention_mask.unsqueeze(1) if attention_mask is not None else None
        )
        cls_output = enhanced_output[:, 0, :]
        cls_output = self.dropout(cls_output)
        # Dual path classification
        regular_logits = self.regular_classifier(cls_output)
        boundary_logits = self.boundary_classifier(cls_output)
        # Boundary detection
        boundary_logits_raw = self.boundary_detector(cls_output).squeeze(-1)
        boundary_score = torch.sigmoid(boundary_logits_raw)
        # Dynamic fusion
        boundary_bias = torch.zeros_like(regular_logits)
        boundary_bias[:, 1] = boundary_score * self.boundary_force_weight
        final_logits = regular_logits + boundary_bias
        loss = None
        if labels is not None:
            regular_loss = self.focal_loss(regular_logits, labels)
            boundary_loss = self.focal_loss(boundary_logits, labels)
            final_loss = self.focal_loss(final_logits, labels)
            boundary_labels = self._generate_boundary_labels(labels)
            detection_loss = F.binary_cross_entropy_with_logits(boundary_logits_raw, boundary_labels)
            total_loss = (0.4 * final_loss +
                          0.3 * regular_loss +
                          0.2 * boundary_loss +
                          0.1 * detection_loss)
            loss = total_loss
        return {
            'loss': loss,
            'logits': final_logits,
            'regular_logits': regular_logits,
            'boundary_logits': boundary_logits,
            'boundary_score': boundary_score,
            'hidden_states': enhanced_output,
            'attention_weights': attention_weights
        }

    def _generate_boundary_labels(self, labels):
        """Generate heuristic labels for boundary detection"""
        boundary_labels = labels.float()
        noise = torch.rand_like(boundary_labels) * 0.1
        boundary_labels = torch.clamp(boundary_labels + noise, 0.0, 1.0)
        return boundary_labels

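# Shape summary for DualPathBoundaryClassifier.forward (B = batch, L = seq len,
# H = hidden size), as implemented above:
#   sequence_output, enhanced_output: (B, L, H)
#   cls_output: (B, H) -> regular_logits / boundary_logits: (B, 2)
#   boundary_score: (B,), added (scaled by boundary_force_weight) to the
#   class-1 logit only, so a high boundary score pushes the fused prediction
#   toward "different paragraph".
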
class SentencePairTestDataset(Dataset):
    """Sentence pair test dataset"""

    def __init__(self, data, tokenizer, max_length=384):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Filter valid data
        self.valid_data = [item for item in data if item['label'] in [0, 1]]
        logger.info(f"Original test data: {len(data)} items, Valid data: {len(self.valid_data)} items")
        self.sentence1_list = [item['sentence1'] for item in self.valid_data]
        self.sentence2_list = [item['sentence2'] for item in self.valid_data]
        self.labels = [item['label'] for item in self.valid_data]

    def __len__(self):
        return len(self.valid_data)

    def __getitem__(self, idx):
        sentence1 = str(self.sentence1_list[idx])
        sentence2 = str(self.sentence2_list[idx])
        label = self.labels[idx]
        encoding = self.tokenizer(
            sentence1,
            sentence2,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long),
            'sentence1': sentence1,
            'sentence2': sentence2
        }

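# The test JSON is expected to be a list of dicts with these keys (values below
# are illustrative placeholders, not real data):
#   {"sentence1": "...", "sentence2": "...", "label": 0}
# where label 0 = same paragraph and 1 = different paragraph; items with any
# other label are dropped by SentencePairTestDataset.
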
def load_trained_model(model_path, tokenizer_path):
    """Load the trained dual path model"""
    logger.info(f"Loading trained model from: {model_path}")
    # Load tokenizer
    tokenizer = BertTokenizer.from_pretrained(tokenizer_path)
    # Load model configuration
    config_path = os.path.join(model_path, 'config.json')
    if os.path.exists(config_path):
        with open(config_path, 'r', encoding='utf-8') as f:
            model_config = json.load(f)
        logger.info(f"Model config loaded: {model_config.get('model_type', 'Unknown')}")
    # Initialize model
    model = DualPathBoundaryClassifier(
        model_path=tokenizer_path,  # Use original model path for RoBERTa base
        num_labels=2,
        dropout=0.1,
        focal_alpha=None,
        focal_gamma=3.0,
        boundary_force_weight=2.0
    )
    # Load trained weights
    model_weights_path = os.path.join(model_path, 'pytorch_model.bin')
    if os.path.exists(model_weights_path):
        model.load_state_dict(torch.load(model_weights_path, map_location='cpu'))
        logger.info("✅ Trained model weights loaded successfully")
    else:
        raise FileNotFoundError(f"Model weights not found at: {model_weights_path}")
    return model, tokenizer

def evaluate_model(model, test_dataloader, device, output_dir):
    """Evaluate the model on test dataset"""
    model.eval()
    all_predictions = []
    all_labels = []
    all_probabilities = []
    all_boundary_scores = []
    logger.info("🔍 Starting model evaluation...")
    with torch.no_grad():
        for batch_idx, batch in enumerate(test_dataloader):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )
            logits = outputs['logits']
            boundary_scores = outputs['boundary_score']
            # Get predictions
            probabilities = torch.softmax(logits, dim=-1)
            predictions = torch.argmax(logits, dim=-1)
            all_predictions.extend(predictions.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probabilities.extend(probabilities.cpu().numpy())
            all_boundary_scores.extend(boundary_scores.cpu().numpy())
            if (batch_idx + 1) % 100 == 0:
                logger.info(f" Processed {batch_idx + 1} batches...")
    # Convert to numpy arrays
    all_predictions = np.array(all_predictions)
    all_labels = np.array(all_labels)
    all_probabilities = np.array(all_probabilities)
    all_boundary_scores = np.array(all_boundary_scores)
    # Calculate metrics
    accuracy = accuracy_score(all_labels, all_predictions)
    precision, recall, f1, support = precision_recall_fscore_support(all_labels, all_predictions, average=None)
    # Generate detailed classification report
    class_report = classification_report(all_labels, all_predictions,
                                         target_names=['Same Paragraph (0)', 'Different Paragraph (1)'],
                                         output_dict=True)
    # Generate confusion matrix
    cm = confusion_matrix(all_labels, all_predictions)
    logger.info("📊 Test Results:")
    logger.info(f" Overall Accuracy: {accuracy:.4f}")
    logger.info(
        f" Class 0 (Same Paragraph) - Precision: {precision[0]:.4f}, Recall: {recall[0]:.4f}, F1: {f1[0]:.4f}")
    logger.info(
        f" Class 1 (Different Paragraph) - Precision: {precision[1]:.4f}, Recall: {recall[1]:.4f}, F1: {f1[1]:.4f}")
    # Save results
    results = {
        'overall_accuracy': float(accuracy),
        'class_metrics': {
            'class_0_same_paragraph': {
                'precision': float(precision[0]),
                'recall': float(recall[0]),
                'f1_score': float(f1[0]),
                'support': int(support[0])
            },
            'class_1_different_paragraph': {
                'precision': float(precision[1]),
                'recall': float(recall[1]),
                'f1_score': float(f1[1]),
                'support': int(support[1])
            }
        },
        'confusion_matrix': cm.tolist(),
        'classification_report': class_report,
        'test_samples_count': len(all_predictions),
        'boundary_score_stats': {
            'mean': float(np.mean(all_boundary_scores)),
            'std': float(np.std(all_boundary_scores)),
            'min': float(np.min(all_boundary_scores)),
            'max': float(np.max(all_boundary_scores))
        }
    }
    return results, cm, all_predictions, all_labels, all_probabilities, all_boundary_scores

def plot_confusion_matrix(cm, output_dir, model_name="Dual Path Boundary Classifier"):
    """Plot and save confusion matrix"""
    plt.figure(figsize=(10, 8))
    # Create heatmap
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=['Same Paragraph (0)', 'Different Paragraph (1)'],
                yticklabels=['Same Paragraph (0)', 'Different Paragraph (1)'],
                cbar_kws={'label': 'Number of Samples'})
    plt.title(f'Confusion Matrix - {model_name}\nTest Dataset Evaluation',
              fontsize=16, fontweight='bold', pad=20)
    plt.xlabel('Predicted Label', fontsize=14, fontweight='bold')
    plt.ylabel('True Label', fontsize=14, fontweight='bold')
    # Add accuracy text
    accuracy = np.trace(cm) / np.sum(cm)
    plt.text(0.5, -0.15, f'Overall Accuracy: {accuracy:.4f}',
             ha='center', transform=plt.gca().transAxes, fontsize=12,
             bbox=dict(boxstyle="round,pad=0.3", facecolor="lightblue", alpha=0.7))
    # Add sample counts
    total_samples = np.sum(cm)
    plt.text(0.5, -0.25, f'Total Test Samples: {total_samples}',
             ha='center', transform=plt.gca().transAxes, fontsize=10)
    plt.tight_layout()
    # Save plot
    save_path = os.path.join(output_dir, 'confusion_matrix_test_results.png')
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    logger.info(f"📊 Confusion matrix saved: {save_path}")
    return save_path

def plot_class_distribution(results, output_dir):
    """Plot class distribution and performance metrics"""
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 12))
    # Class distribution
    class_0_count = results['class_metrics']['class_0_same_paragraph']['support']
    class_1_count = results['class_metrics']['class_1_different_paragraph']['support']
    ax1.bar(['Same Paragraph (0)', 'Different Paragraph (1)'],
            [class_0_count, class_1_count],
            color=['skyblue', 'lightcoral'])
    ax1.set_title('Test Dataset Class Distribution', fontweight='bold')
    ax1.set_ylabel('Number of Samples')
    # Add count labels on bars
    ax1.text(0, class_0_count + max(class_0_count, class_1_count) * 0.01, str(class_0_count),
             ha='center', fontweight='bold')
    ax1.text(1, class_1_count + max(class_0_count, class_1_count) * 0.01, str(class_1_count),
             ha='center', fontweight='bold')
    # Precision comparison
    precision_0 = results['class_metrics']['class_0_same_paragraph']['precision']
    precision_1 = results['class_metrics']['class_1_different_paragraph']['precision']
    ax2.bar(['Same Paragraph (0)', 'Different Paragraph (1)'],
            [precision_0, precision_1],
            color=['lightgreen', 'orange'])
    ax2.set_title('Precision by Class', fontweight='bold')
    ax2.set_ylabel('Precision Score')
    ax2.set_ylim(0, 1)
    # Recall comparison
    recall_0 = results['class_metrics']['class_0_same_paragraph']['recall']
    recall_1 = results['class_metrics']['class_1_different_paragraph']['recall']
    ax3.bar(['Same Paragraph (0)', 'Different Paragraph (1)'],
            [recall_0, recall_1],
            color=['lightcyan', 'plum'])
    ax3.set_title('Recall by Class', fontweight='bold')
    ax3.set_ylabel('Recall Score')
    ax3.set_ylim(0, 1)
    # F1-Score comparison
    f1_0 = results['class_metrics']['class_0_same_paragraph']['f1_score']
    f1_1 = results['class_metrics']['class_1_different_paragraph']['f1_score']
    ax4.bar(['Same Paragraph (0)', 'Different Paragraph (1)'],
            [f1_0, f1_1],
            color=['gold', 'mediumpurple'])
    ax4.set_title('F1-Score by Class', fontweight='bold')
    ax4.set_ylabel('F1-Score')
    ax4.set_ylim(0, 1)
    plt.suptitle('Dual Path Boundary Classifier - Test Performance Analysis',
                 fontsize=16, fontweight='bold')
    plt.tight_layout()
    # Save plot
    save_path = os.path.join(output_dir, 'class_performance_analysis.png')
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    logger.info(f"📊 Class performance analysis saved: {save_path}")
    return save_path

def main():
    """Main function"""
    logger.info("=" * 80)
    logger.info("🚀 Dual Path Boundary Classifier - Test Evaluation")
    logger.info("=" * 80)
    # Configuration
    original_model_path = r"D:\workstation\chinese-roberta-wwm-ext\model"
    trained_model_path = r"D:\workstation\chinese-roberta-wwm-ext\model-train-eval-NN\model_train"
    test_file = r"D:\workstation\AI标注\数据清洗+json\test_dataset.json"
    output_dir = r"D:\workstation\AI标注\test"
    # Ensure output directory exists
    os.makedirs(output_dir, exist_ok=True)
    # Check GPU availability
    gpu_available, gpu_memory = check_gpu_availability()
    device = torch.device('cuda' if gpu_available else 'cpu')
    logger.info(f"📋 Test Configuration:")
    logger.info(f" 🔹 Original Model: {original_model_path}")
    logger.info(f" 🔹 Trained Model: {trained_model_path}")
    logger.info(f" 🔹 Test Dataset: {test_file}")
    logger.info(f" 🔹 Output Directory: {output_dir}")
    logger.info(f" 🔹 Max Length: 384 tokens")
    logger.info(f" 🔹 Device: {device}")
    try:
        # Load test data
        logger.info("📥 Loading test dataset...")
        with open(test_file, 'r', encoding='utf-8') as f:
            test_data = json.load(f)
        logger.info(f" Loaded {len(test_data)} test samples")
        # Load trained model
        model, tokenizer = load_trained_model(trained_model_path, original_model_path)
        model = model.to(device)
        # Create test dataset
        test_dataset = SentencePairTestDataset(test_data, tokenizer, max_length=384)
        test_dataloader = DataLoader(
            test_dataset,
            batch_size=32,  # Optimized for RTX 4060
            shuffle=False,
            num_workers=4,
            pin_memory=True if gpu_available else False
        )
        logger.info(f" Test dataset size: {len(test_dataset)}")
        logger.info(f" Batch size: 32")
        # Evaluate model
        start_time = datetime.now()
        results, cm, predictions, labels, probabilities, boundary_scores = evaluate_model(
            model, test_dataloader, device, output_dir
        )
        end_time = datetime.now()
        evaluation_time = end_time - start_time
        logger.info(f" Evaluation completed in: {evaluation_time}")
        # Generate visualizations
        logger.info("📊 Generating visualizations...")
        # Plot confusion matrix
        cm_path = plot_confusion_matrix(cm, output_dir)
        # Plot class performance analysis
        perf_path = plot_class_distribution(results, output_dir)
        # Save detailed results
        results['evaluation_info'] = {
            'evaluation_time': str(evaluation_time),
            'device_used': str(device),
            'model_type': 'DualPathBoundaryClassifier',
            'max_length': 384,
            'batch_size': 32,
            'test_file': test_file,
            'trained_model_path': trained_model_path
        }
        results_path = os.path.join(output_dir, 'test_results_detailed.json')
        with open(results_path, 'w', encoding='utf-8') as f:
            json.dump(results, f, ensure_ascii=False, indent=2)
        # Save predictions
        predictions_data = []
        for i in range(len(predictions)):
            predictions_data.append({
                'index': i,
                'sentence1': test_dataset[i]['sentence1'],
                'sentence2': test_dataset[i]['sentence2'],
                'true_label': int(labels[i]),
                'predicted_label': int(predictions[i]),
                'probability_class_0': float(probabilities[i][0]),
                'probability_class_1': float(probabilities[i][1]),
                'boundary_score': float(boundary_scores[i]),
                'correct': bool(labels[i] == predictions[i])
            })
        predictions_path = os.path.join(output_dir, 'detailed_predictions.json')
        with open(predictions_path, 'w', encoding='utf-8') as f:
            json.dump(predictions_data, f, ensure_ascii=False, indent=2)
        # Generate summary report
        summary = {
            'model_info': {
                'model_type': 'Dual Path Boundary Classifier',
                'base_model': 'Chinese-RoBERTa-WWM-Ext',
                'max_length': 384,
                'trained_model_path': trained_model_path
            },
            'test_results': {
                'overall_accuracy': results['overall_accuracy'],
                'total_samples': len(predictions),
                'correct_predictions': int(np.sum(labels == predictions)),
                'incorrect_predictions': int(np.sum(labels != predictions))
            },
            'class_performance': results['class_metrics'],
            'boundary_detection': results['boundary_score_stats'],
            'files_generated': [
                'test_results_detailed.json',
                'detailed_predictions.json',
                'confusion_matrix_test_results.png',
                'class_performance_analysis.png',
                'test_summary.json'
            ]
        }
        summary_path = os.path.join(output_dir, 'test_summary.json')
        with open(summary_path, 'w', encoding='utf-8') as f:
            json.dump(summary, f, ensure_ascii=False, indent=2)
        # Print final results
        logger.info("=" * 80)
        logger.info("🎉 Test Evaluation Completed!")
        logger.info("=" * 80)
        logger.info(f"📊 Final Results:")
        logger.info(f" 🔹 Overall Accuracy: {results['overall_accuracy']:.4f}")
        logger.info(f" 🔹 Total Test Samples: {len(predictions)}")
        logger.info(f" 🔹 Correct Predictions: {np.sum(labels == predictions)}")
        logger.info(f" 🔹 Evaluation Time: {evaluation_time}")
        logger.info(f"\n📈 Class Performance:")
        logger.info(f" Class 0 (Same Paragraph):")
        logger.info(f" • Precision: {results['class_metrics']['class_0_same_paragraph']['precision']:.4f}")
        logger.info(f" • Recall: {results['class_metrics']['class_0_same_paragraph']['recall']:.4f}")
        logger.info(f" • F1-Score: {results['class_metrics']['class_0_same_paragraph']['f1_score']:.4f}")
        logger.info(f" Class 1 (Different Paragraph):")
        logger.info(f" • Precision: {results['class_metrics']['class_1_different_paragraph']['precision']:.4f}")
        logger.info(f" • Recall: {results['class_metrics']['class_1_different_paragraph']['recall']:.4f}")
        logger.info(f" • F1-Score: {results['class_metrics']['class_1_different_paragraph']['f1_score']:.4f}")
        logger.info(f"\n📁 Generated Files in {output_dir}:")
        logger.info(f" 📄 test_summary.json - Test evaluation summary")
        logger.info(f" 📄 test_results_detailed.json - Detailed test results")
        logger.info(f" 📄 detailed_predictions.json - Individual predictions")
        logger.info(f" 📊 confusion_matrix_test_results.png - Confusion matrix")
        logger.info(f" 📊 class_performance_analysis.png - Performance analysis")
        logger.info(f"\n🎯 Model Performance Summary:")
        logger.info(f" ✅ Dual Path Boundary Classifier successfully evaluated")
        logger.info(f" ✅ Optimized for RTX 4060 with max_length=384")
        logger.info(f" ✅ English visualizations generated")
        logger.info(f" ✅ All results saved to: {output_dir}")
    except Exception as e:
        logger.error(f"❌ Error during evaluation: {str(e)}")
        import traceback
        traceback.print_exc()
        raise

if __name__ == "__main__":
    main()
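# To run, adjust the hard-coded paths in main() and execute this file directly
# (the filename is whatever this script is saved as). Results are written to
# output_dir: test_summary.json, test_results_detailed.json,
# detailed_predictions.json, confusion_matrix_test_results.png,
# class_performance_analysis.png.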