import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, precision_recall_fscore_support
from transformers import BertTokenizer, BertModel
from torch.utils.data import Dataset, DataLoader
import logging
import os
import json
from datetime import datetime
import math

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Use English-capable matplotlib fonts
plt.rcParams['font.sans-serif'] = ['Arial', 'DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False


def check_gpu_availability():
    """Check GPU availability and report basic device information."""
    if torch.cuda.is_available():
        gpu_count = torch.cuda.device_count()
        gpu_name = torch.cuda.get_device_name(0)
        gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024 ** 3

        logger.info("✅ GPU Available!")
        logger.info(f"   🔹 GPU Count: {gpu_count}")
        logger.info(f"   🔹 GPU Model: {gpu_name}")
        logger.info(f"   🔹 GPU Memory: {gpu_memory:.1f} GB")

        torch.cuda.empty_cache()
        torch.backends.cudnn.benchmark = True
        return True, gpu_memory
    else:
        logger.info("⚠️ GPU not available, using CPU")
        return False, 0


class FocalLoss(nn.Module):
    """Focal Loss for handling class imbalance."""

    def __init__(self, alpha=None, gamma=3.0, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)

        if self.alpha is not None:
            if self.alpha.type() != inputs.data.type():
                self.alpha = self.alpha.type_as(inputs.data)
            at = self.alpha.gather(0, targets.data.view(-1))
            ce_loss = ce_loss * at

        focal_weight = (1 - pt) ** self.gamma
        focal_loss = focal_weight * ce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss
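

# For reference, FocalLoss above implements FL(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t),
# so gamma=3.0 down-weights well-classified examples more aggressively than the common
# default of 2.0. The helper below is an illustrative sanity check only (it is not called
# anywhere in this script, and its tensor shapes are made-up placeholders).
def _focal_loss_sanity_check():
    """Illustrative only: compare FocalLoss against plain cross-entropy on random logits."""
    torch.manual_seed(0)
    logits = torch.randn(8, 2)           # (batch, num_labels)
    targets = torch.randint(0, 2, (8,))  # binary labels in {0, 1}
    focal = FocalLoss(alpha=None, gamma=3.0)(logits, targets)
    ce = F.cross_entropy(logits, targets)
    # Without alpha, the focal loss is <= cross-entropy, since (1 - p_t)**gamma <= 1.
    # Per-class weights would be passed as e.g. alpha=torch.tensor([0.25, 0.75]).
    return focal.item(), ce.item()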


class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention."""

    def __init__(self, d_model, dropout=0.1):
        super(ScaledDotProductAttention, self).__init__()
        self.d_model = d_model
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        batch_size, seq_len, d_model = query.size()

        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_model)

        if mask is not None:
            mask_value = torch.finfo(scores.dtype).min
            scores = scores.masked_fill(mask == 0, mask_value)

        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        output = torch.matmul(attention_weights, value)

        return output, attention_weights
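

# The module above computes single-head attention, Attention(Q, K, V) = softmax(Q K^T / sqrt(d_model)) V,
# over inputs of shape (batch, seq_len, d_model); the padding mask of shape (batch, 1, seq_len)
# broadcasts across query positions. It scales by sqrt(d_model) rather than a per-head dimension,
# which is equivalent here because there is only one head. The helper below is illustrative only
# and is not called by this script; 768 is assumed as the hidden size.
def _attention_shape_example():
    """Illustrative only: show the shapes ScaledDotProductAttention expects."""
    attn = ScaledDotProductAttention(d_model=768)
    x = torch.randn(2, 5, 768)                    # (batch, seq_len, hidden)
    mask = torch.ones(2, 1, 5, dtype=torch.long)  # 1 = keep, 0 = padding
    out, weights = attn(x, x, x, mask=mask)
    return out.shape, weights.shape               # (2, 5, 768), (2, 5, 5)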


class DualPathBoundaryClassifier(nn.Module):
    """Dual Path Boundary Classifier with Pure Neural Network Learning."""

    def __init__(self, model_path, num_labels=2, dropout=0.1,
                 focal_alpha=None, focal_gamma=3.0, boundary_force_weight=2.0):
        super(DualPathBoundaryClassifier, self).__init__()

        # chinese-roberta-wwm-ext uses the BERT architecture, so it is loaded with
        # BertModel even though the attribute is named `roberta`.
        self.roberta = BertModel.from_pretrained(model_path)
        self.config = self.roberta.config
        self.config.num_labels = num_labels

        self.scaled_attention = ScaledDotProductAttention(
            d_model=self.config.hidden_size,
            dropout=dropout
        )

        self.dropout = nn.Dropout(dropout)

        # Dual path classifiers
        self.regular_classifier = nn.Linear(self.config.hidden_size, num_labels)
        self.boundary_classifier = nn.Linear(self.config.hidden_size, num_labels)
        self.boundary_detector = nn.Linear(self.config.hidden_size, 1)

        # Boundary force weight (learnable)
        self.boundary_force_weight = nn.Parameter(torch.tensor(boundary_force_weight))

        self.focal_loss = FocalLoss(alpha=focal_alpha, gamma=focal_gamma)

        self._init_weights()

        self.focal_alpha = focal_alpha
        self.focal_gamma = focal_gamma

    def _init_weights(self):
        """Initialize weights for the newly added layers."""
        nn.init.normal_(self.regular_classifier.weight, std=0.02)
        nn.init.zeros_(self.regular_classifier.bias)
        nn.init.normal_(self.boundary_classifier.weight, std=0.02)
        nn.init.zeros_(self.boundary_classifier.bias)
        nn.init.normal_(self.boundary_detector.weight, std=0.02)
        nn.init.zeros_(self.boundary_detector.bias)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, labels=None):
        roberta_outputs = self.roberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=True
        )

        sequence_output = roberta_outputs.last_hidden_state

        # Enhance the encoder output with an extra scaled-attention pass
        enhanced_output, attention_weights = self.scaled_attention(
            query=sequence_output,
            key=sequence_output,
            value=sequence_output,
            mask=attention_mask.unsqueeze(1) if attention_mask is not None else None
        )

        cls_output = enhanced_output[:, 0, :]
        cls_output = self.dropout(cls_output)

        # Dual path classification
        regular_logits = self.regular_classifier(cls_output)
        boundary_logits = self.boundary_classifier(cls_output)

        # Boundary detection
        boundary_logits_raw = self.boundary_detector(cls_output).squeeze(-1)
        boundary_score = torch.sigmoid(boundary_logits_raw)

        # Dynamic fusion: push the "different paragraph" logit up by the boundary score
        boundary_bias = torch.zeros_like(regular_logits)
        boundary_bias[:, 1] = boundary_score * self.boundary_force_weight

        final_logits = regular_logits + boundary_bias

        loss = None
        if labels is not None:
            regular_loss = self.focal_loss(regular_logits, labels)
            boundary_loss = self.focal_loss(boundary_logits, labels)
            final_loss = self.focal_loss(final_logits, labels)

            boundary_labels = self._generate_boundary_labels(labels)
            detection_loss = F.binary_cross_entropy_with_logits(boundary_logits_raw, boundary_labels)

            total_loss = (0.4 * final_loss +
                          0.3 * regular_loss +
                          0.2 * boundary_loss +
                          0.1 * detection_loss)
            loss = total_loss

        return {
            'loss': loss,
            'logits': final_logits,
            'regular_logits': regular_logits,
            'boundary_logits': boundary_logits,
            'boundary_score': boundary_score,
            'hidden_states': enhanced_output,
            'attention_weights': attention_weights
        }

    def _generate_boundary_labels(self, labels):
        """Generate heuristic (noisy) targets for the boundary detector."""
        boundary_labels = labels.float()
        noise = torch.rand_like(boundary_labels) * 0.1
        boundary_labels = torch.clamp(boundary_labels + noise, 0.0, 1.0)
        return boundary_labels
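

# Illustrative only: a minimal single-pair inference sketch for the classifier above.
# It mirrors the tokenization used by SentencePairTestDataset; the function is not
# called anywhere in this script, and the argument values are up to the caller.
def _predict_single_pair(model, tokenizer, sentence1, sentence2, device, max_length=384):
    """Return (predicted_label, boundary_score) for one sentence pair."""
    model.eval()
    encoding = tokenizer(
        sentence1,
        sentence2,
        truncation=True,
        padding='max_length',
        max_length=max_length,
        return_tensors='pt'
    )
    with torch.no_grad():
        outputs = model(
            input_ids=encoding['input_ids'].to(device),
            attention_mask=encoding['attention_mask'].to(device)
        )
    predicted_label = int(torch.argmax(outputs['logits'], dim=-1).item())
    boundary_score = float(outputs['boundary_score'].item())
    return predicted_label, boundary_score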


class SentencePairTestDataset(Dataset):
    """Sentence pair test dataset."""

    def __init__(self, data, tokenizer, max_length=384):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Filter valid data
        self.valid_data = [item for item in data if item['label'] in [0, 1]]
        logger.info(f"Original test data: {len(data)} items, Valid data: {len(self.valid_data)} items")

        self.sentence1_list = [item['sentence1'] for item in self.valid_data]
        self.sentence2_list = [item['sentence2'] for item in self.valid_data]
        self.labels = [item['label'] for item in self.valid_data]

    def __len__(self):
        return len(self.valid_data)

    def __getitem__(self, idx):
        sentence1 = str(self.sentence1_list[idx])
        sentence2 = str(self.sentence2_list[idx])
        label = self.labels[idx]

        encoding = self.tokenizer(
            sentence1,
            sentence2,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long),
            'sentence1': sentence1,
            'sentence2': sentence2
        }
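

# The dataset above expects test_dataset.json to be a JSON list of records shaped
# like the following (the sentences here are made-up placeholders):
#
#     [
#         {"sentence1": "第一句。", "sentence2": "第二句。", "label": 0},
#         {"sentence1": "上一段末句。", "sentence2": "新一段首句。", "label": 1}
#     ]
#
# where label 0 means the two sentences belong to the same paragraph and label 1
# means they sit on a paragraph boundary; records with any other label are dropped.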


def load_trained_model(model_path, tokenizer_path):
    """Load the trained dual path model."""
    logger.info(f"Loading trained model from: {model_path}")

    # Load tokenizer
    tokenizer = BertTokenizer.from_pretrained(tokenizer_path)

    # Load model configuration (logged for reference only)
    config_path = os.path.join(model_path, 'config.json')
    if os.path.exists(config_path):
        with open(config_path, 'r', encoding='utf-8') as f:
            model_config = json.load(f)
        logger.info(f"Model config loaded: {model_config.get('model_type', 'Unknown')}")

    # Initialize model
    model = DualPathBoundaryClassifier(
        model_path=tokenizer_path,  # Use the original model path for the RoBERTa base
        num_labels=2,
        dropout=0.1,
        focal_alpha=None,
        focal_gamma=3.0,
        boundary_force_weight=2.0
    )

    # Load trained weights
    model_weights_path = os.path.join(model_path, 'pytorch_model.bin')
    if os.path.exists(model_weights_path):
        model.load_state_dict(torch.load(model_weights_path, map_location='cpu'))
        logger.info("✅ Trained model weights loaded successfully")
    else:
        raise FileNotFoundError(f"Model weights not found at: {model_weights_path}")

    return model, tokenizer


def evaluate_model(model, test_dataloader, device, output_dir):
    """Evaluate the model on the test dataset."""
    model.eval()
    all_predictions = []
    all_labels = []
    all_probabilities = []
    all_boundary_scores = []

    logger.info("🔍 Starting model evaluation...")

    with torch.no_grad():
        for batch_idx, batch in enumerate(test_dataloader):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )

            logits = outputs['logits']
            boundary_scores = outputs['boundary_score']

            # Get predictions
            probabilities = torch.softmax(logits, dim=-1)
            predictions = torch.argmax(logits, dim=-1)

            all_predictions.extend(predictions.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probabilities.extend(probabilities.cpu().numpy())
            all_boundary_scores.extend(boundary_scores.cpu().numpy())

            if (batch_idx + 1) % 100 == 0:
                logger.info(f"   Processed {batch_idx + 1} batches...")

    # Convert to numpy arrays
    all_predictions = np.array(all_predictions)
    all_labels = np.array(all_labels)
    all_probabilities = np.array(all_probabilities)
    all_boundary_scores = np.array(all_boundary_scores)

    # Calculate metrics
    accuracy = accuracy_score(all_labels, all_predictions)
    precision, recall, f1, support = precision_recall_fscore_support(all_labels, all_predictions, average=None)

    # Generate detailed classification report
    class_report = classification_report(all_labels, all_predictions,
                                         target_names=['Same Paragraph (0)', 'Different Paragraph (1)'],
                                         output_dict=True)

    # Generate confusion matrix
    cm = confusion_matrix(all_labels, all_predictions)

    logger.info("📊 Test Results:")
    logger.info(f"   Overall Accuracy: {accuracy:.4f}")
    logger.info(
        f"   Class 0 (Same Paragraph) - Precision: {precision[0]:.4f}, Recall: {recall[0]:.4f}, F1: {f1[0]:.4f}")
    logger.info(
        f"   Class 1 (Different Paragraph) - Precision: {precision[1]:.4f}, Recall: {recall[1]:.4f}, F1: {f1[1]:.4f}")

    # Save results
    results = {
        'overall_accuracy': float(accuracy),
        'class_metrics': {
            'class_0_same_paragraph': {
                'precision': float(precision[0]),
                'recall': float(recall[0]),
                'f1_score': float(f1[0]),
                'support': int(support[0])
            },
            'class_1_different_paragraph': {
                'precision': float(precision[1]),
                'recall': float(recall[1]),
                'f1_score': float(f1[1]),
                'support': int(support[1])
            }
        },
        'confusion_matrix': cm.tolist(),
        'classification_report': class_report,
        'test_samples_count': len(all_predictions),
        'boundary_score_stats': {
            'mean': float(np.mean(all_boundary_scores)),
            'std': float(np.std(all_boundary_scores)),
            'min': float(np.min(all_boundary_scores)),
            'max': float(np.max(all_boundary_scores))
        }
    }

    return results, cm, all_predictions, all_labels, all_probabilities, all_boundary_scores
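

# Note on the metrics computed above: with average=None, the precision/recall/f1/support
# arrays are indexed by class label, so index 0 is "Same Paragraph" and index 1 is
# "Different Paragraph". The overall accuracy equals np.trace(cm) / np.sum(cm), which is
# also how plot_confusion_matrix() below annotates the saved figure.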


def plot_confusion_matrix(cm, output_dir, model_name="Dual Path Boundary Classifier"):
    """Plot and save the confusion matrix."""
    plt.figure(figsize=(10, 8))

    # Create heatmap
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=['Same Paragraph (0)', 'Different Paragraph (1)'],
                yticklabels=['Same Paragraph (0)', 'Different Paragraph (1)'],
                cbar_kws={'label': 'Number of Samples'})

    plt.title(f'Confusion Matrix - {model_name}\nTest Dataset Evaluation',
              fontsize=16, fontweight='bold', pad=20)
    plt.xlabel('Predicted Label', fontsize=14, fontweight='bold')
    plt.ylabel('True Label', fontsize=14, fontweight='bold')

    # Add accuracy text
    accuracy = np.trace(cm) / np.sum(cm)
    plt.text(0.5, -0.15, f'Overall Accuracy: {accuracy:.4f}',
             ha='center', transform=plt.gca().transAxes, fontsize=12,
             bbox=dict(boxstyle="round,pad=0.3", facecolor="lightblue", alpha=0.7))

    # Add sample counts
    total_samples = np.sum(cm)
    plt.text(0.5, -0.25, f'Total Test Samples: {total_samples}',
             ha='center', transform=plt.gca().transAxes, fontsize=10)

    plt.tight_layout()

    # Save plot
    save_path = os.path.join(output_dir, 'confusion_matrix_test_results.png')
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()

    logger.info(f"📊 Confusion matrix saved: {save_path}")
    return save_path


def plot_class_distribution(results, output_dir):
    """Plot class distribution and per-class performance metrics."""
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 12))

    # Class distribution
    class_0_count = results['class_metrics']['class_0_same_paragraph']['support']
    class_1_count = results['class_metrics']['class_1_different_paragraph']['support']

    ax1.bar(['Same Paragraph (0)', 'Different Paragraph (1)'],
            [class_0_count, class_1_count],
            color=['skyblue', 'lightcoral'])
    ax1.set_title('Test Dataset Class Distribution', fontweight='bold')
    ax1.set_ylabel('Number of Samples')

    # Add count labels on bars
    ax1.text(0, class_0_count + max(class_0_count, class_1_count) * 0.01, str(class_0_count),
             ha='center', fontweight='bold')
    ax1.text(1, class_1_count + max(class_0_count, class_1_count) * 0.01, str(class_1_count),
             ha='center', fontweight='bold')

    # Precision comparison
    precision_0 = results['class_metrics']['class_0_same_paragraph']['precision']
    precision_1 = results['class_metrics']['class_1_different_paragraph']['precision']

    ax2.bar(['Same Paragraph (0)', 'Different Paragraph (1)'],
            [precision_0, precision_1],
            color=['lightgreen', 'orange'])
    ax2.set_title('Precision by Class', fontweight='bold')
    ax2.set_ylabel('Precision Score')
    ax2.set_ylim(0, 1)

    # Recall comparison
    recall_0 = results['class_metrics']['class_0_same_paragraph']['recall']
    recall_1 = results['class_metrics']['class_1_different_paragraph']['recall']

    ax3.bar(['Same Paragraph (0)', 'Different Paragraph (1)'],
            [recall_0, recall_1],
            color=['lightcyan', 'plum'])
    ax3.set_title('Recall by Class', fontweight='bold')
    ax3.set_ylabel('Recall Score')
    ax3.set_ylim(0, 1)

    # F1-Score comparison
    f1_0 = results['class_metrics']['class_0_same_paragraph']['f1_score']
    f1_1 = results['class_metrics']['class_1_different_paragraph']['f1_score']

    ax4.bar(['Same Paragraph (0)', 'Different Paragraph (1)'],
            [f1_0, f1_1],
            color=['gold', 'mediumpurple'])
    ax4.set_title('F1-Score by Class', fontweight='bold')
    ax4.set_ylabel('F1-Score')
    ax4.set_ylim(0, 1)

    plt.suptitle('Dual Path Boundary Classifier - Test Performance Analysis',
                 fontsize=16, fontweight='bold')
    plt.tight_layout()

    # Save plot
    save_path = os.path.join(output_dir, 'class_performance_analysis.png')
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()

    logger.info(f"📊 Class performance analysis saved: {save_path}")
    return save_path


def main():
    """Main function."""
    logger.info("=" * 80)
    logger.info("🚀 Dual Path Boundary Classifier - Test Evaluation")
    logger.info("=" * 80)

    # Configuration
    original_model_path = r"D:\workstation\chinese-roberta-wwm-ext\model"
    trained_model_path = r"D:\workstation\chinese-roberta-wwm-ext\model-train-eval-NN\model_train"
    test_file = r"D:\workstation\AI标注\数据清洗+json\test_dataset.json"
    output_dir = r"D:\workstation\AI标注\test"

    # Ensure output directory exists
    os.makedirs(output_dir, exist_ok=True)

    # Check GPU availability
    gpu_available, gpu_memory = check_gpu_availability()
    device = torch.device('cuda' if gpu_available else 'cpu')

    logger.info("📋 Test Configuration:")
    logger.info(f"   🔹 Original Model: {original_model_path}")
    logger.info(f"   🔹 Trained Model: {trained_model_path}")
    logger.info(f"   🔹 Test Dataset: {test_file}")
    logger.info(f"   🔹 Output Directory: {output_dir}")
    logger.info("   🔹 Max Length: 384 tokens")
    logger.info(f"   🔹 Device: {device}")

    try:
        # Load test data
        logger.info("📥 Loading test dataset...")
        with open(test_file, 'r', encoding='utf-8') as f:
            test_data = json.load(f)
        logger.info(f"   Loaded {len(test_data)} test samples")

        # Load trained model
        model, tokenizer = load_trained_model(trained_model_path, original_model_path)
        model = model.to(device)

        # Create test dataset
        test_dataset = SentencePairTestDataset(test_data, tokenizer, max_length=384)
        test_dataloader = DataLoader(
            test_dataset,
            batch_size=32,  # Optimized for RTX 4060
            shuffle=False,
            num_workers=4,
            pin_memory=True if gpu_available else False
        )

        logger.info(f"   Test dataset size: {len(test_dataset)}")
        logger.info("   Batch size: 32")

        # Evaluate model
        start_time = datetime.now()
        results, cm, predictions, labels, probabilities, boundary_scores = evaluate_model(
            model, test_dataloader, device, output_dir
        )
        end_time = datetime.now()

        evaluation_time = end_time - start_time
        logger.info(f"⏱️ Evaluation completed in: {evaluation_time}")

        # Generate visualizations
        logger.info("📊 Generating visualizations...")

        # Plot confusion matrix
        cm_path = plot_confusion_matrix(cm, output_dir)

        # Plot class performance analysis
        perf_path = plot_class_distribution(results, output_dir)

        # Save detailed results
        results['evaluation_info'] = {
            'evaluation_time': str(evaluation_time),
            'device_used': str(device),
            'model_type': 'DualPathBoundaryClassifier',
            'max_length': 384,
            'batch_size': 32,
            'test_file': test_file,
            'trained_model_path': trained_model_path
        }

        results_path = os.path.join(output_dir, 'test_results_detailed.json')
        with open(results_path, 'w', encoding='utf-8') as f:
            json.dump(results, f, ensure_ascii=False, indent=2)

        # Save predictions (read sentences from the stored lists instead of
        # indexing the dataset, which would re-tokenize every sample)
        predictions_data = []
        for i in range(len(predictions)):
            predictions_data.append({
                'index': i,
                'sentence1': str(test_dataset.sentence1_list[i]),
                'sentence2': str(test_dataset.sentence2_list[i]),
                'true_label': int(labels[i]),
                'predicted_label': int(predictions[i]),
                'probability_class_0': float(probabilities[i][0]),
                'probability_class_1': float(probabilities[i][1]),
                'boundary_score': float(boundary_scores[i]),
                'correct': bool(labels[i] == predictions[i])
            })

        predictions_path = os.path.join(output_dir, 'detailed_predictions.json')
        with open(predictions_path, 'w', encoding='utf-8') as f:
            json.dump(predictions_data, f, ensure_ascii=False, indent=2)

        # Generate summary report
        summary = {
            'model_info': {
                'model_type': 'Dual Path Boundary Classifier',
                'base_model': 'Chinese-RoBERTa-WWM-Ext',
                'max_length': 384,
                'trained_model_path': trained_model_path
            },
            'test_results': {
                'overall_accuracy': results['overall_accuracy'],
                'total_samples': len(predictions),
                'correct_predictions': int(np.sum(labels == predictions)),
                'incorrect_predictions': int(np.sum(labels != predictions))
            },
            'class_performance': results['class_metrics'],
            'boundary_detection': results['boundary_score_stats'],
            'files_generated': [
                'test_results_detailed.json',
                'detailed_predictions.json',
                'confusion_matrix_test_results.png',
                'class_performance_analysis.png',
                'test_summary.json'
            ]
        }

        summary_path = os.path.join(output_dir, 'test_summary.json')
        with open(summary_path, 'w', encoding='utf-8') as f:
            json.dump(summary, f, ensure_ascii=False, indent=2)

        # Print final results
        logger.info("=" * 80)
        logger.info("🎉 Test Evaluation Completed!")
        logger.info("=" * 80)
        logger.info("📊 Final Results:")
        logger.info(f"   🔹 Overall Accuracy: {results['overall_accuracy']:.4f}")
        logger.info(f"   🔹 Total Test Samples: {len(predictions)}")
        logger.info(f"   🔹 Correct Predictions: {np.sum(labels == predictions)}")
        logger.info(f"   🔹 Evaluation Time: {evaluation_time}")

        logger.info("\n📈 Class Performance:")
        logger.info("   Class 0 (Same Paragraph):")
        logger.info(f"     • Precision: {results['class_metrics']['class_0_same_paragraph']['precision']:.4f}")
        logger.info(f"     • Recall: {results['class_metrics']['class_0_same_paragraph']['recall']:.4f}")
        logger.info(f"     • F1-Score: {results['class_metrics']['class_0_same_paragraph']['f1_score']:.4f}")

        logger.info("   Class 1 (Different Paragraph):")
        logger.info(f"     • Precision: {results['class_metrics']['class_1_different_paragraph']['precision']:.4f}")
        logger.info(f"     • Recall: {results['class_metrics']['class_1_different_paragraph']['recall']:.4f}")
        logger.info(f"     • F1-Score: {results['class_metrics']['class_1_different_paragraph']['f1_score']:.4f}")

        logger.info(f"\n📁 Generated Files in {output_dir}:")
        logger.info("   📄 test_summary.json - Test evaluation summary")
        logger.info("   📄 test_results_detailed.json - Detailed test results")
        logger.info("   📄 detailed_predictions.json - Individual predictions")
        logger.info("   📊 confusion_matrix_test_results.png - Confusion matrix")
        logger.info("   📊 class_performance_analysis.png - Performance analysis")

        logger.info("\n🎯 Model Performance Summary:")
        logger.info("   ✅ Dual Path Boundary Classifier successfully evaluated")
        logger.info("   ✅ Optimized for RTX 4060 with max_length=384")
        logger.info("   ✅ English visualizations generated")
        logger.info(f"   ✅ All results saved to: {output_dir}")

    except Exception as e:
        logger.error(f"❌ Error during evaluation: {str(e)}")
        import traceback
        traceback.print_exc()
        raise


if __name__ == "__main__":
    main()