|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import pandas as pd |
|
|
import numpy as np |
|
|
import matplotlib.pyplot as plt |
|
|
import seaborn as sns |
|
|
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix |
|
|
from sklearn.model_selection import train_test_split |
|
|
from transformers import ( |
|
|
BertTokenizer, |
|
|
BertForSequenceClassification, |
|
|
BertModel, |
|
|
BertConfig, |
|
|
TrainingArguments, |
|
|
Trainer, |
|
|
DataCollatorWithPadding, |
|
|
TrainerCallback, |
|
|
EarlyStoppingCallback |
|
|
) |
|
|
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler |
|
|
import logging |
|
|
import os |
|
|
from datetime import datetime |
|
|
import math |
|
|
import json |
|
|
from collections import defaultdict, Counter |
|
|
|
|
|
# Disable wandb and the other third-party experiment reporters via their
# environment switches.
os.environ.update({
    "WANDB_DISABLED": "true",
    "COMET_MODE": "disabled",
    "NEPTUNE_MODE": "disabled",
})

# Logging setup: INFO level on the root handler, one module-level logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Matplotlib font setup. NOTE(review): the original comment said "Chinese
# font", but the configured family is DejaVu Sans - confirm it covers the
# glyphs used in the figures.
plt.rcParams.update({
    'font.sans-serif': ['DejaVu Sans'],
    'axes.unicode_minus': False,
})
|
|
|
|
|
|
|
|
def check_gpu_availability():
    """Verify that CUDA is usable and log basic facts about device 0.

    Side effects: empties the CUDA caching allocator and enables the cuDNN
    benchmark mode.

    Returns:
        tuple: ``(True, gpu_memory)`` where ``gpu_memory`` is the total memory
        of device 0 in GiB.

    Raises:
        RuntimeError: if ``torch.cuda.is_available()`` is false.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("❌ GPU不可用!此脚本需要GPU支持。")

    gpu_count = torch.cuda.device_count()
    gpu_name = torch.cuda.get_device_name(0)
    # total_memory is reported in bytes; convert to GiB.
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024 ** 3

    for message in (
        f"✅ GPU检查通过!",
        f" 🔹 可用GPU数量: {gpu_count}",
        f" 🔹 GPU型号: {gpu_name}",
        f" 🔹 GPU内存: {gpu_memory:.1f} GB",
    ):
        logger.info(message)

    # GPU tuning: start from a clean allocator cache and let cuDNN benchmark
    # its kernels.
    torch.cuda.empty_cache()
    torch.backends.cudnn.benchmark = True

    return True, gpu_memory
|
|
|
|
|
|
|
|
class LossTracker(TrainerCallback):
    """Trainer callback that records loss and accuracy history for plotting."""

    def __init__(self):
        # Each value list is paired with the global step at which it was logged.
        self.train_losses, self.train_steps = [], []
        self.eval_losses, self.eval_steps = [], []
        self.eval_accuracies = []
        # Epoch bookkeeping, updated from on_epoch_end.
        self.epochs = []
        self.current_epoch = 0
        # Best validation accuracy seen so far and the epoch it was seen in.
        self.best_eval_accuracy = 0.0
        self.best_epoch = 0

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Append any train/eval metrics present in ``logs`` to the history."""
        if not logs:
            return

        step = state.global_step

        if 'loss' in logs:
            self.train_losses.append(logs['loss'])
            self.train_steps.append(step)

        if 'eval_loss' in logs:
            self.eval_losses.append(logs['eval_loss'])
            self.eval_steps.append(step)

        if 'eval_accuracy' in logs:
            accuracy = logs['eval_accuracy']
            self.eval_accuracies.append(accuracy)
            # Track the best validation accuracy (strict improvement only).
            if accuracy > self.best_eval_accuracy:
                self.best_eval_accuracy = accuracy
                self.best_epoch = self.current_epoch

    def on_epoch_end(self, args, state, control, **kwargs):
        """Remember the epoch counter reported by the trainer state."""
        epoch = state.epoch
        self.current_epoch = epoch
        self.epochs.append(epoch)
|
|
|
|
|
|
|
|
class ValidationConfusionMatrixCallback(TrainerCallback):
    """Periodically compute and save a confusion matrix on the validation set.

    A matrix is produced at the end of every ``epochs_interval``-th epoch and
    at the end of the final epoch. Each matrix is kept in
    ``self.confusion_matrices`` (keyed by epoch) and rendered to a PNG file in
    ``output_dir``.
    """

    # The two class ids, fixed so the matrix is always 2x2 and matches the
    # tick labels even when one class is absent from labels and predictions.
    CLASS_IDS = [0, 1]
    TICK_LABELS = ['Same Paragraph (0)', 'Different Paragraph (1)']

    def __init__(self, eval_dataset, tokenizer, output_dir, epochs_interval=10):
        """Store the dataset to score, the output directory and the interval.

        ``tokenizer`` is only stored; this class does not use it itself.
        """
        self.eval_dataset = eval_dataset
        self.tokenizer = tokenizer
        self.output_dir = output_dir
        self.epochs_interval = epochs_interval
        self.confusion_matrices = {}

    def on_epoch_end(self, args, state, control, model=None, **kwargs):
        """Score the validation set and save the matrix when the epoch is due."""
        current_epoch = int(state.epoch)
        is_due = (current_epoch % self.epochs_interval == 0
                  or current_epoch == args.num_train_epochs)
        if not is_due:
            return
        if model is None:
            logger.warning("Confusion matrix skipped: callback received no model.")
            return

        logger.info(f"📊 Generating validation confusion matrix for epoch {current_epoch}...")

        predictions = []
        true_labels = []
        device = next(model.parameters()).device

        # Remember the current mode so it can be restored even if inference fails.
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                for i in range(len(self.eval_dataset)):
                    item = self.eval_dataset[i]
                    # Dataset items are unbatched; add a batch dimension of 1.
                    input_ids = item['input_ids'].unsqueeze(0).to(device)
                    attention_mask = item['attention_mask'].unsqueeze(0).to(device)

                    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                    pred = torch.argmax(outputs['logits'], dim=-1).cpu().item()

                    predictions.append(pred)
                    true_labels.append(item['labels'].item())
        finally:
            model.train(was_training)

        cm = confusion_matrix(true_labels, predictions, labels=self.CLASS_IDS)
        self.confusion_matrices[current_epoch] = cm
        self.save_confusion_matrix(cm, current_epoch)

    def save_confusion_matrix(self, cm, epoch):
        """Render ``cm`` as an annotated heatmap and save it as a PNG."""
        os.makedirs(self.output_dir, exist_ok=True)

        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=self.TICK_LABELS,
                    yticklabels=self.TICK_LABELS)
        plt.title(f'Validation Confusion Matrix - Epoch {epoch}')
        plt.xlabel('Predicted Label')
        plt.ylabel('True Label')

        # Accuracy = diagonal / total; an empty matrix reports 0 instead of NaN.
        total = np.sum(cm)
        accuracy = np.trace(cm) / total if total else 0.0
        plt.text(0.5, -0.15, f'Validation Accuracy: {accuracy:.4f}',
                 ha='center', transform=plt.gca().transAxes)

        plt.tight_layout()

        save_path = os.path.join(self.output_dir, f'validation_confusion_matrix_epoch_{epoch}.png')
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close()

        logger.info(f" 💾 Validation confusion matrix saved: {save_path}")
|
|
|
|
|
|
|
|
class TrainingConfusionMatrixCallback(TrainerCallback):
    """Periodically compute and save a confusion matrix on a training subset.

    A matrix is produced at the end of every ``epochs_interval``-th epoch and
    at the end of the final epoch, using at most ``MAX_SAMPLES`` randomly
    chosen training items. Each matrix is kept in ``self.confusion_matrices``
    (keyed by epoch) and rendered to a PNG file in ``output_dir``.
    """

    # The two class ids, fixed so the matrix is always 2x2 and matches the
    # tick labels even when one class is absent from the sampled subset.
    CLASS_IDS = [0, 1]
    TICK_LABELS = ['Same Paragraph (0)', 'Different Paragraph (1)']
    # Upper bound on the number of training items scored per matrix.
    MAX_SAMPLES = 1000

    def __init__(self, train_dataset, tokenizer, output_dir, epochs_interval=20):
        """Store the dataset to sample, the output directory and the interval.

        ``tokenizer`` is only stored; this class does not use it itself.
        """
        self.train_dataset = train_dataset
        self.tokenizer = tokenizer
        self.output_dir = output_dir
        self.epochs_interval = epochs_interval
        self.confusion_matrices = {}

    def on_epoch_end(self, args, state, control, model=None, **kwargs):
        """Score a random training subset and save the matrix when due."""
        current_epoch = int(state.epoch)
        is_due = (current_epoch % self.epochs_interval == 0
                  or current_epoch == args.num_train_epochs)
        if not is_due:
            return
        if model is None:
            logger.warning("Confusion matrix skipped: callback received no model.")
            return

        logger.info(f"📊 Generating training confusion matrix for epoch {current_epoch}...")

        predictions = []
        true_labels = []
        device = next(model.parameters()).device

        # Only a random subset is scored, to keep the callback fast.
        subset_size = min(self.MAX_SAMPLES, len(self.train_dataset))
        indices = np.random.choice(len(self.train_dataset), subset_size, replace=False)

        # Remember the current mode so it can be restored even if inference fails.
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                for i in indices:
                    # Plain int so any dataset __getitem__ accepts the index.
                    item = self.train_dataset[int(i)]
                    # Dataset items are unbatched; add a batch dimension of 1.
                    input_ids = item['input_ids'].unsqueeze(0).to(device)
                    attention_mask = item['attention_mask'].unsqueeze(0).to(device)

                    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                    pred = torch.argmax(outputs['logits'], dim=-1).cpu().item()

                    predictions.append(pred)
                    true_labels.append(item['labels'].item())
        finally:
            model.train(was_training)

        cm = confusion_matrix(true_labels, predictions, labels=self.CLASS_IDS)
        self.confusion_matrices[current_epoch] = cm
        self.save_confusion_matrix(cm, current_epoch)

    def save_confusion_matrix(self, cm, epoch):
        """Render ``cm`` as an annotated heatmap and save it as a PNG."""
        os.makedirs(self.output_dir, exist_ok=True)

        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Greens',
                    xticklabels=self.TICK_LABELS,
                    yticklabels=self.TICK_LABELS)
        plt.title(f'Training Confusion Matrix - Epoch {epoch}')
        plt.xlabel('Predicted Label')
        plt.ylabel('True Label')

        # Accuracy = diagonal / total; an empty matrix reports 0 instead of NaN.
        total = np.sum(cm)
        accuracy = np.trace(cm) / total if total else 0.0
        plt.text(0.5, -0.15, f'Training Accuracy: {accuracy:.4f}',
                 ha='center', transform=plt.gca().transAxes)

        plt.tight_layout()

        save_path = os.path.join(self.output_dir, f'training_confusion_matrix_epoch_{epoch}.png')
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close()

        logger.info(f" 💾 Training confusion matrix saved: {save_path}")
|
|
|
|
|
|
|
|
def _draw_loss_comparison(loss_tracker, title, title_size, label_size=None):
    """Draw the full training and validation loss curves on the current axes."""
    # Each series keeps its own x values (global steps); no truncation is
    # needed because matplotlib plots the two lines independently.
    plt.plot(loss_tracker.train_steps, loss_tracker.train_losses,
             'b-', label='Training Loss', linewidth=2, alpha=0.8)
    plt.plot(loss_tracker.eval_steps, loss_tracker.eval_losses,
             'r-', label='Validation Loss', linewidth=2, alpha=0.8)

    text_kwargs = {} if label_size is None else {'fontsize': label_size}
    plt.title(title, fontsize=title_size, fontweight='bold')
    plt.xlabel('Training Steps', **text_kwargs)
    plt.ylabel('Loss Value', **text_kwargs)
    plt.legend(**text_kwargs)
    plt.grid(True, alpha=0.3)


def plot_training_curves(loss_tracker, output_dir):
    """Plot training/validation loss and validation accuracy curves.

    Writes two PNG files into ``output_dir``:
    ``comprehensive_training_curves.png`` (a 2x2 panel figure) and, when both
    loss histories are non-empty, ``loss_comparison_curves.png``.

    Args:
        loss_tracker: object exposing ``train_steps``, ``train_losses``,
            ``eval_steps``, ``eval_losses``, ``eval_accuracies``,
            ``best_eval_accuracy`` and ``best_epoch``.
        output_dir: directory for the PNG files; created if missing.
    """
    os.makedirs(output_dir, exist_ok=True)
    has_train = bool(loss_tracker.train_losses)
    has_eval = bool(loss_tracker.eval_losses)

    plt.figure(figsize=(15, 10))

    # Panel 1: training loss with a linear trend line.
    if has_train:
        plt.subplot(2, 2, 1)
        plt.plot(loss_tracker.train_steps, loss_tracker.train_losses,
                 'b-', label='Training Loss', linewidth=2, alpha=0.8)
        plt.title('Training Loss Curve', fontsize=14, fontweight='bold')
        plt.xlabel('Training Steps')
        plt.ylabel('Loss Value')
        plt.legend()
        plt.grid(True, alpha=0.3)

        if len(loss_tracker.train_losses) > 10:
            z = np.polyfit(loss_tracker.train_steps, loss_tracker.train_losses, 1)
            p = np.poly1d(z)
            plt.plot(loss_tracker.train_steps, p(loss_tracker.train_steps),
                     'r--', alpha=0.6, label='Trend Line')
            plt.legend()

    # Panel 2: validation loss.
    if has_eval:
        plt.subplot(2, 2, 2)
        plt.plot(loss_tracker.eval_steps, loss_tracker.eval_losses,
                 'g-', label='Validation Loss', linewidth=2, alpha=0.8)
        plt.title('Validation Loss Curve', fontsize=14, fontweight='bold')
        plt.xlabel('Training Steps')
        plt.ylabel('Loss Value')
        plt.legend()
        plt.grid(True, alpha=0.3)

    # Panel 3: validation accuracy.
    if loss_tracker.eval_accuracies:
        plt.subplot(2, 2, 3)
        # The step list is filled from eval_loss entries; trim both to a common
        # length so a log record lacking one of the two cannot break plotting.
        n_points = min(len(loss_tracker.eval_steps), len(loss_tracker.eval_accuracies))
        plt.plot(loss_tracker.eval_steps[:n_points], loss_tracker.eval_accuracies[:n_points],
                 'purple', label='Validation Accuracy', linewidth=2, alpha=0.8, marker='o', markersize=3)
        plt.title('Validation Accuracy Curve', fontsize=14, fontweight='bold')
        plt.xlabel('Training Steps')
        plt.ylabel('Accuracy')
        plt.legend()
        plt.grid(True, alpha=0.3)

        # Mark the best accuracy.
        if loss_tracker.best_eval_accuracy > 0:
            plt.axhline(y=loss_tracker.best_eval_accuracy, color='red', linestyle='--', alpha=0.7)
            plt.text(0.02, 0.98, f'Best: {loss_tracker.best_eval_accuracy:.4f} (Epoch {loss_tracker.best_epoch})',
                     transform=plt.gca().transAxes, fontsize=10, verticalalignment='top',
                     bbox=dict(boxstyle="round,pad=0.3", facecolor="yellow", alpha=0.7))

    # Panel 4: training vs validation loss.
    if has_train and has_eval:
        plt.subplot(2, 2, 4)
        _draw_loss_comparison(loss_tracker, 'Training vs Validation Loss', 14)

        # Overfitting hint: compare the most recent training losses with the
        # most recent validation losses (the tail of each full history).
        if min(len(loss_tracker.train_losses), len(loss_tracker.eval_losses)) > 20:
            recent_train = np.mean(loss_tracker.train_losses[-10:])
            recent_eval = np.mean(loss_tracker.eval_losses[-10:])
            if recent_eval > recent_train * 1.2:
                plt.text(0.7, 0.9, '⚠️ Possible Overfitting', transform=plt.gca().transAxes,
                         bbox=dict(boxstyle="round,pad=0.3", facecolor="orange", alpha=0.7))

    plt.tight_layout()

    # Save the combined figure.
    curves_path = os.path.join(output_dir, 'comprehensive_training_curves.png')
    plt.savefig(curves_path, dpi=300, bbox_inches='tight')
    plt.close()

    logger.info(f"📈 Comprehensive training curves saved: {curves_path}")

    # Save the loss comparison on its own, larger figure.
    if has_train and has_eval:
        plt.figure(figsize=(12, 6))
        _draw_loss_comparison(loss_tracker, 'Training vs Validation Loss Comparison', 16, label_size=12)

        compare_path = os.path.join(output_dir, 'loss_comparison_curves.png')
        plt.savefig(compare_path, dpi=300, bbox_inches='tight')
        plt.close()
        logger.info(f"📈 Loss comparison curves saved: {compare_path}")
|
|
|
|
|
|
|
|
class SentencePairDataset(Dataset):
    """Sentence-pair dataset with optional class-balanced sampling weights.

    Only records whose ``label`` is 0 or 1 are kept. For a training dataset
    (``is_validation=False``) per-class and per-sample weights are computed so
    that :meth:`get_weighted_sampler` can build a ``WeightedRandomSampler``.
    """

    def __init__(self, data, tokenizer, max_length=384, is_validation=False):
        """Filter ``data`` and cache sentence/label lists.

        Args:
            data: iterable of dicts with ``sentence1``, ``sentence2``, ``label``.
            tokenizer: callable invoked as
                ``tokenizer(s1, s2, truncation=..., padding=..., max_length=...,
                return_tensors='pt')`` returning ``input_ids`` and
                ``attention_mask`` tensors.
            max_length: padded/truncated sequence length.
            is_validation: when true, no sampling weights are computed.
        """
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.is_validation = is_validation

        # .get() lets records without a label be dropped instead of raising.
        self.valid_data = [item for item in data if item.get('label') in (0, 1)]
        dataset_type = "验证" if is_validation else "训练"
        logger.info(f"原始{dataset_type}数据: {len(data)} 条,有效数据: {len(self.valid_data)} 条")

        self.sentence1_list = [item['sentence1'] for item in self.valid_data]
        self.sentence2_list = [item['sentence2'] for item in self.valid_data]
        self.labels = [item['label'] for item in self.valid_data]

        # Weights and sampler inputs are computed for the training set only.
        if not is_validation:
            self.class_counts = Counter(self.labels)
            self.class_weights = self._compute_class_weights()
            self.sample_weights = self._compute_sample_weights()

    def _compute_class_weights(self):
        """Return ``{label: total / (2 * count)}`` for labels 0 and 1.

        A class with no samples gets weight 0.0; the value is never used for
        sampling because no sample carries that label.
        """
        total_samples = len(self.labels)
        class_weights = {}
        for label in (0, 1):
            count = self.class_counts.get(label, 0)
            class_weights[label] = total_samples / (2 * count) if count else 0.0
        return class_weights

    def _compute_sample_weights(self):
        """Return a float tensor holding each sample's class weight."""
        return torch.tensor([self.class_weights[label] for label in self.labels],
                            dtype=torch.float)

    def get_weighted_sampler(self):
        """Return a ``WeightedRandomSampler`` for the training set.

        Raises:
            ValueError: for a validation dataset, or when there are no samples.
        """
        if self.is_validation:
            raise ValueError("验证集不需要加权采样器")
        if len(self.sample_weights) == 0:
            raise ValueError("Cannot build a weighted sampler for an empty dataset.")
        return WeightedRandomSampler(
            weights=self.sample_weights,
            num_samples=len(self.sample_weights),
            replacement=True
        )

    def __len__(self):
        return len(self.valid_data)

    def __getitem__(self, idx):
        """Tokenize pair ``idx`` and return ``input_ids``/``attention_mask``/``labels``."""
        sentence1 = str(self.sentence1_list[idx])
        sentence2 = str(self.sentence2_list[idx])
        label = self.labels[idx]

        encoding = self.tokenizer(
            sentence1,
            sentence2,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )

        # return_tensors='pt' yields a leading batch dimension; flatten() drops it.
        # NOTE(review): token_type_ids from the tokenizer are not forwarded -
        # confirm that is intended for sentence-pair input.
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }
|
|
|
|
|
|
|
|
def load_and_split_data(train_file, validation_split=0.2, random_state=42):
    """Load a JSON dataset and split it into stratified train/validation sets.

    Args:
        train_file: path to a UTF-8 JSON file holding a list of records, each
            a dict with a ``label`` key.
        validation_split: fraction of the valid records used for validation.
        random_state: seed forwarded to ``train_test_split``.

    Returns:
        tuple: ``(train_data, val_data)`` lists of records, or ``(None, None)``
        if loading or splitting fails (the error is logged with its traceback).
    """
    try:
        with open(train_file, 'r', encoding='utf-8') as f:
            all_data = json.load(f)
        logger.info(f"成功加载原始数据: {len(all_data)} 条记录")

        # Keep only records labelled 0 or 1; .get() lets a record without a
        # label be skipped instead of aborting the whole load.
        valid_data = [item for item in all_data if item.get('label') in (0, 1)]
        logger.info(f"有效数据: {len(valid_data)} 条记录")

        # Stratified split by label.
        labels = [item['label'] for item in valid_data]
        train_data, val_data = train_test_split(
            valid_data,
            test_size=validation_split,
            random_state=random_state,
            stratify=labels
        )

        logger.info(f"数据划分完成:")
        logger.info(f" 🔹 训练集: {len(train_data)} 条")
        logger.info(f" 🔹 验证集: {len(val_data)} 条")
        logger.info(f" 🔹 验证集比例: {validation_split * 100:.1f}%")

        # Report the label distribution of both splits.
        train_counts = Counter(item['label'] for item in train_data)
        val_counts = Counter(item['label'] for item in val_data)

        logger.info(
            f"训练集分布: 标签0={train_counts[0]}({train_counts[0] / len(train_data) * 100:.1f}%), 标签1={train_counts[1]}({train_counts[1] / len(train_data) * 100:.1f}%)")
        logger.info(
            f"验证集分布: 标签0={val_counts[0]}({val_counts[0] / len(val_data) * 100:.1f}%), 标签1={val_counts[1]}({val_counts[1] / len(val_data) * 100:.1f}%)")

        return train_data, val_data

    except Exception as e:
        # Best-effort contract: callers get (None, None); exc_info keeps the
        # traceback in the log so the root cause is not lost.
        logger.error(f"加载和划分数据失败: {str(e)}", exc_info=True)
        return None, None
|
|
|
|
|
|
|
|
def analyze_data_distribution(data):
    """Analyse the label distribution and derive Focal Loss parameters.

    Args:
        data: iterable of dicts; only records whose ``label`` is 0 or 1 count.

    Returns:
        tuple: ``(label_counts, alpha_weights, recommended_gamma)`` where
        ``label_counts`` maps each label present to its count,
        ``alpha_weights`` is ``[alpha_for_label_0, alpha_for_label_1]`` and
        ``recommended_gamma`` is 2.5, 3.0 or 3.5 depending on the imbalance.

    Raises:
        ValueError: if ``data`` has no record labelled 0 or 1.
    """
    labels = [item['label'] for item in data if item.get('label') in (0, 1)]
    label_counts = dict(Counter(labels))
    total_samples = len(labels)
    if total_samples == 0:
        raise ValueError("analyze_data_distribution: no records with label 0 or 1.")

    logger.info("=== 训练数据分布分析 ===")
    logger.info(f"总有效记录数: {total_samples}")

    class_proportions = {}
    for label in (0, 1):
        count = label_counts.get(label, 0)
        proportion = count / total_samples
        class_proportions[label] = proportion

        label_name = "同段落" if label == 0 else "不同段落"
        logger.info(f"标签 {label} ({label_name}): {count} 条 ({proportion * 100:.2f}%)")

    minority_label = min((0, 1), key=class_proportions.get)
    minority_ratio = class_proportions[minority_label]
    majority_ratio = max(class_proportions.values())
    # A class with no samples is treated as infinitely imbalanced rather than
    # dividing by zero.
    imbalance_ratio = majority_ratio / minority_ratio if minority_ratio > 0 else float('inf')

    logger.info(f"\n📊 数据不平衡分析:")
    logger.info(f" 🔹 少数类比例: {minority_ratio * 100:.2f}%")
    logger.info(f" 🔹 不平衡比率: {imbalance_ratio:.2f}:1")

    if imbalance_ratio > 5:
        # Aggressive weighting: 0.9 goes to whichever class is the minority,
        # not unconditionally to label 1.
        alpha_weights = [0.1, 0.1]
        alpha_weights[minority_label] = 0.9
        logger.info(" 🎯 使用激进的alpha权重设置")
    else:
        # Each class is weighted by the proportion of the other class.
        alpha_weights = [1.0 - class_proportions[0], 1.0 - class_proportions[1]]

    if imbalance_ratio > 6:
        recommended_gamma = 3.5
        logger.info(" ⚠️ 严重不平衡 - 使用Gamma=3.5强化聚焦")
    elif imbalance_ratio > 4:
        recommended_gamma = 3.0
        logger.info(" ⚠️ 中度偏严重不平衡 - 使用Gamma=3.0")
    else:
        recommended_gamma = 2.5

    logger.info(f"\n🎯 优化的Focal Loss参数设置:")
    # The entries are indexed by label, so they are logged by label rather
    # than as "majority/minority".
    logger.info(f" 🔹 Alpha权重: [标签0={alpha_weights[0]:.3f}, 标签1={alpha_weights[1]:.3f}]")
    logger.info(f" 🔹 优化Gamma: {recommended_gamma} (增强难样本聚焦)")
    logger.info(f" 🔹 公式: FL(p_t) = -α_t * (1-p_t)^γ * log(p_t)")
    logger.info(f" 🔹 加权采样: 启用WeightedRandomSampler")

    return label_counts, alpha_weights, recommended_gamma
|
|
|
|
|
|
|
|
def compute_metrics(eval_pred):
    """Compute accuracy from a ``(predictions, labels)`` evaluation pair.

    ``predictions`` may be a list/tuple of outputs (the first element is
    used), a non-array sequence, or an array of rank 1, 2 or 3; it is coerced
    to a 2-D score array before the arg-max is taken.

    Returns:
        dict: ``{'accuracy': <accuracy_score result>}``.
    """
    predictions, labels = eval_pred

    # A list/tuple of outputs: keep only the first entry (usually the logits).
    if isinstance(predictions, (list, tuple)):
        predictions = predictions[0]

    # Make sure we are working with a numpy array.
    if not isinstance(predictions, np.ndarray):
        predictions = np.array(predictions)

    rank = predictions.ndim
    if rank == 3:
        # 3-D input: take the last index along axis 1.
        predictions = predictions[:, -1, :]
    elif rank == 1:
        # 1-D input: fold into rows of two scores.
        predictions = predictions.reshape(-1, 2)

    # Anything that is still not 2-D is flattened into rows of two scores.
    if predictions.ndim != 2:
        logger.warning(f"Unexpected predictions shape: {predictions.shape}")
        predictions = predictions.reshape(-1, 2)

    # Predicted class = index of the largest score in each row.
    try:
        predictions = np.argmax(predictions, axis=1)
    except Exception as e:
        logger.error(f"Error in argmax: {e}")
        logger.error(f"Predictions shape: {predictions.shape}")
        logger.error(f"Predictions dtype: {predictions.dtype}")
        # Fall back to a per-row arg-max, defaulting to class 0 for short rows.
        predictions = np.array([np.argmax(pred) if len(pred) > 1 else 0 for pred in predictions])

    return {
        'accuracy': accuracy_score(labels, predictions),
    }
|
|
|
|
|
|
|
|
class FocalLoss(nn.Module):
    """Focal Loss for class-imbalanced classification.

    Computes ``alpha_t * (1 - p_t) ** gamma * CE`` per sample, where ``CE`` is
    the cross-entropy and ``p_t = exp(-CE)``.

    Args:
        alpha: optional per-class weights, as a 1-D tensor or a list/tuple of
            floats indexed by class id. ``None`` disables class weighting.
        gamma: focusing exponent.
        reduction: ``'mean'``, ``'sum'``; any other value returns the
            unreduced per-sample loss.
    """

    def __init__(self, alpha=None, gamma=3.0, reduction='mean'):
        super().__init__()
        # Accept plain sequences as well as tensors.
        if alpha is not None and not torch.is_tensor(alpha):
            alpha = torch.as_tensor(alpha, dtype=torch.float)
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Return the focal loss of ``inputs`` (logits) against ``targets`` (class ids)."""
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = (1 - pt) ** self.gamma * ce_loss

        if self.alpha is not None:
            # Cast/move a local copy; the stored tensor is left untouched so a
            # forward pass never mutates module state.
            alpha = self.alpha.to(device=ce_loss.device, dtype=ce_loss.dtype)
            focal_loss = focal_loss * alpha.gather(0, targets.view(-1))

        if self.reduction == 'mean':
            return focal_loss.mean()
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss
|
|
|
|
|
|
|
|
class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: ``softmax(Q K^T / sqrt(d)) V``.

    Args:
        d_model: stored on the module; the scaling factor itself is taken
            from the last dimension of ``query`` at call time.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, d_model, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        """Attend ``query`` over ``key``/``value``.

        Inputs may carry any number of leading dimensions (for example
        ``(batch, seq, dim)`` or ``(batch, heads, seq, dim)``); only the last
        two are used by the matrix products.

        Args:
            mask: optional tensor broadcastable to the score matrix; positions
                where ``mask == 0`` are excluded from the softmax.

        Returns:
            tuple: ``(output, attention_weights)``.
        """
        scale = math.sqrt(query.size(-1))
        scores = torch.matmul(query, key.transpose(-2, -1)) / scale

        if mask is not None:
            # Use the dtype's most negative finite value rather than -inf.
            mask_value = torch.finfo(scores.dtype).min
            scores = scores.masked_fill(mask == 0, mask_value)

        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        output = torch.matmul(attention_weights, value)

        return output, attention_weights
|
|
|
|
|
|
|
|
class DualPathBoundaryClassifier(nn.Module):
    """Dual-path boundary classifier on top of a BERT-style encoder.

    The encoder output passes through an extra scaled dot-product attention
    layer. The first-position vector then feeds three linear heads: a regular
    classifier, a boundary classifier and a scalar boundary detector. The
    detector's sigmoid score, scaled by a learnable weight, is added as a bias
    to the class-1 logit of the regular head to form the final logits.
    """

    def __init__(self, model_path, num_labels=2, dropout=0.1,
                 focal_alpha=None, focal_gamma=3.0, boundary_force_weight=2.0):
        """Build the encoder and heads.

        Args:
            model_path: path/name passed to ``BertModel.from_pretrained``.
            num_labels: number of output classes.
            dropout: dropout probability for the attention layer and the
                pooled vector.
            focal_alpha: optional per-class weights for the focal loss.
            focal_gamma: focal loss focusing exponent.
            boundary_force_weight: initial value of the learnable boundary
                bias scale.
        """
        super().__init__()

        self.roberta = BertModel.from_pretrained(model_path)
        self.config = self.roberta.config
        self.config.num_labels = num_labels

        self.scaled_attention = ScaledDotProductAttention(
            d_model=self.config.hidden_size,
            dropout=dropout
        )

        self.dropout = nn.Dropout(dropout)

        # Dual-path heads.
        self.regular_classifier = nn.Linear(self.config.hidden_size, num_labels)   # regular classifier
        self.boundary_classifier = nn.Linear(self.config.hidden_size, num_labels)  # boundary classifier
        self.boundary_detector = nn.Linear(self.config.hidden_size, 1)             # boundary detector

        # Learnable scale of the boundary bias added to the class-1 logit.
        self.boundary_force_weight = nn.Parameter(torch.tensor(boundary_force_weight))

        self.focal_loss = FocalLoss(alpha=focal_alpha, gamma=focal_gamma)

        self._init_weights()

        self.focal_alpha = focal_alpha
        self.focal_gamma = focal_gamma

        logger.info(f"🔧 双路径分类器初始化完成,边界强制权重: {boundary_force_weight}")
        logger.info(f"🤖 完全依靠神经网络学习边界模式,无预定义关键词")

    def _init_weights(self):
        """Initialise the added heads: N(0, 0.02) weights and zero biases."""
        for head in (self.regular_classifier, self.boundary_classifier, self.boundary_detector):
            nn.init.normal_(head.weight, std=0.02)
            nn.init.zeros_(head.bias)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, labels=None):
        """Run the encoder and heads; compute the combined loss when ``labels`` is given.

        Returns:
            dict with keys ``loss`` (``None`` without labels), ``logits``,
            ``regular_logits``, ``boundary_logits``, ``boundary_score``,
            ``hidden_states`` and ``attention_weights``.
        """
        roberta_outputs = self.roberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=True
        )
        sequence_output = roberta_outputs.last_hidden_state

        # Extra self-attention over the encoder output; the padding mask gets
        # a singleton dimension so it broadcasts over query positions.
        key_mask = attention_mask.unsqueeze(1) if attention_mask is not None else None
        enhanced_output, attention_weights = self.scaled_attention(
            query=sequence_output,
            key=sequence_output,
            value=sequence_output,
            mask=key_mask
        )

        # First-position vector as the pooled representation.
        cls_output = self.dropout(enhanced_output[:, 0, :])

        # Dual-path classification.
        regular_logits = self.regular_classifier(cls_output)
        boundary_logits = self.boundary_classifier(cls_output)

        # Boundary detector: keep the raw logit for the loss, and the sigmoid
        # score for the final prediction.
        boundary_logits_raw = self.boundary_detector(cls_output).squeeze(-1)
        boundary_score = torch.sigmoid(boundary_logits_raw)

        # Dynamic fusion: bias only the class-1 logit by the boundary score.
        boundary_bias = torch.zeros_like(regular_logits)
        boundary_bias[:, 1] = boundary_score * self.boundary_force_weight

        final_logits = regular_logits + boundary_bias

        loss = None
        if labels is not None:
            # Multi-task loss over the three sets of logits.
            regular_loss = self.focal_loss(regular_logits, labels)
            boundary_loss = self.focal_loss(boundary_logits, labels)
            final_loss = self.focal_loss(final_logits, labels)

            # Detector loss against heuristic boundary targets; the
            # with-logits form avoids applying BCE to sigmoid outputs.
            boundary_labels = self._generate_boundary_labels(labels)
            detection_loss = F.binary_cross_entropy_with_logits(boundary_logits_raw, boundary_labels)

            loss = (0.4 * final_loss +
                    0.3 * regular_loss +
                    0.2 * boundary_loss +
                    0.1 * detection_loss)

        # NOTE(review): hidden_states and attention_weights are large per-batch
        # tensors; if an evaluation loop gathers every returned value they may
        # dominate memory - confirm how the trainer handles these keys.
        return {
            'loss': loss,
            'logits': final_logits,
            'regular_logits': regular_logits,
            'boundary_logits': boundary_logits,
            'boundary_score': boundary_score,
            'hidden_states': enhanced_output,
            'attention_weights': attention_weights
        }

    def _generate_boundary_labels(self, labels):
        """Return float targets for the boundary detector derived from ``labels``.

        Heuristic: the class label itself is the target. While training,
        uniform noise in ``[0, 0.1)`` is added and the result clamped to
        ``[0, 1]``. In eval mode the targets are exact, so the evaluation loss
        is deterministic.
        """
        boundary_labels = labels.float()

        if self.training:
            noise = torch.rand_like(boundary_labels) * 0.1
            boundary_labels = torch.clamp(boundary_labels + noise, 0.0, 1.0)

        return boundary_labels

    def save_pretrained(self, save_directory):
        """Save the state dict (``pytorch_model.bin``) and a ``config.json`` summary."""
        os.makedirs(save_directory, exist_ok=True)

        model_to_save = self.module if hasattr(self, 'module') else self
        torch.save(model_to_save.state_dict(), os.path.join(save_directory, 'pytorch_model.bin'))

        # focal_alpha may be a tensor or a plain sequence; make it JSON-friendly.
        focal_alpha = self.focal_alpha
        if focal_alpha is not None:
            focal_alpha = focal_alpha.tolist() if hasattr(focal_alpha, 'tolist') else list(focal_alpha)

        config_dict = {
            'model_type': 'DualPathBoundaryClassifier',
            'base_model': 'chinese-roberta-wwm-ext',
            'num_labels': self.config.num_labels,
            'hidden_size': self.config.hidden_size,
            'focal_alpha': focal_alpha,
            'focal_gamma': self.focal_gamma,
            'boundary_force_weight': float(self.boundary_force_weight.data),
            'boundary_detection_method': 'pure_neural_network',
            'has_scaled_attention': True,
            'has_focal_loss': True,
            'has_dual_path': True,
            'optimization_level': 'dual_path_boundary_classification_neural_only'
        }

        with open(os.path.join(save_directory, 'config.json'), 'w', encoding='utf-8') as f:
            json.dump(config_dict, f, ensure_ascii=False, indent=2)
|
|
|
|
|
|
|
|
class WeightedTrainer(Trainer):
    """Trainer variant whose training DataLoader can use a caller-supplied sampler."""

    def __init__(self, weighted_sampler=None, *args, **kwargs):
        """Forward all other arguments to ``Trainer`` and keep the sampler."""
        super().__init__(*args, **kwargs)
        self.weighted_sampler = weighted_sampler

    def get_train_dataloader(self):
        """Build the training DataLoader.

        Uses ``self.weighted_sampler`` when one was given, otherwise the
        sampler returned by ``self._get_train_sampler()``.

        Raises:
            ValueError: if no training dataset is set.
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")

        sampler = self.weighted_sampler
        if sampler is None:
            sampler = self._get_train_sampler()

        loader_options = {
            'batch_size': self.args.train_batch_size,
            'sampler': sampler,
            'collate_fn': self.data_collator,
            'drop_last': self.args.dataloader_drop_last,
            'num_workers': self.args.dataloader_num_workers,
            'pin_memory': self.args.dataloader_pin_memory,
        }
        return DataLoader(self.train_dataset, **loader_options)
|
|
|
|
|
|
|
|
def _dump_json_file(path, payload):
    """Serialize ``payload`` to ``path`` as indented UTF-8 JSON, keeping non-ASCII text as is."""
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(payload, f, ensure_ascii=False, indent=2)


def _confusion_matrix_epochs(interval, total_epochs):
    """Return every multiple of ``interval`` up to ``total_epochs``, plus the final epoch, without duplicates."""
    epochs = list(range(interval, total_epochs + 1, interval))
    if total_epochs not in epochs:
        epochs.append(total_epochs)
    return epochs


def train_roberta_model(train_data, val_data,
                        model_path="/root/autodl-tmp/Robert-wwm-ext",
                        output_dir="/root/autodl-tmp/model_chinese-roberta-wwm-ext/model_train",
                        checkpoint_dir="/root/autodl-tmp/model_chinese-roberta-wwm-ext/ouput_result"):
    """Train the dual-path boundary classifier and keep the best model on validation accuracy.

    The function requires a CUDA device (``check_gpu_availability`` raises
    otherwise), builds the tokenizer, model, datasets and ``WeightedTrainer``,
    trains with per-epoch evaluation and checkpointing, and then writes:

    * to ``output_dir``: the model (via ``save_pretrained``), the tokenizer and
      ``best_model_info.json``;
    * to ``checkpoint_dir``: the trainer checkpoints, ``training_history.json``,
      ``validation_confusion_matrix_history.json``,
      ``training_confusion_matrix_history.json`` and ``training_info.json``.
      ``plot_training_curves`` is also called with this directory.

    Args:
        train_data: Training samples, passed to ``analyze_data_distribution``
            and ``SentencePairDataset``.
        val_data: Validation samples, passed to ``SentencePairDataset``.
        model_path: Directory of the pretrained tokenizer/model.
        output_dir: Directory receiving the final model and its summary.
        checkpoint_dir: Directory receiving checkpoints and training records.

    Returns:
        Tuple ``(trainer, trainer.model, tokenizer, loss_tracker,
        val_confusion_matrix_callback, train_confusion_matrix_callback)``.

    Raises:
        RuntimeError: If no GPU is available, or re-raised from training
            (an out-of-memory error is logged with a hint first).
    """
    # Only the reported memory size is reused below; the call itself raises without a GPU.
    _, gpu_memory = check_gpu_availability()
    device = torch.device('cuda')
    logger.info(f"🚀 使用GPU设备: {device}")

    # The class distribution drives the focal-loss alpha weights and gamma.
    label_distribution, alpha_weights, recommended_gamma = analyze_data_distribution(train_data)
    alpha_tensor = torch.tensor(alpha_weights, dtype=torch.float).to(device)

    logger.info(f"📥 加载Chinese-RoBERTa-WWM-Ext模型: {model_path}")
    tokenizer = BertTokenizer.from_pretrained(model_path)

    # Dual-path classifier with a deliberately low boundary force weight.
    model = DualPathBoundaryClassifier(
        model_path=model_path,
        num_labels=2,
        dropout=0.1,
        focal_alpha=alpha_tensor,
        focal_gamma=recommended_gamma,
        boundary_force_weight=2.0,
    )
    model = model.to(device)
    logger.info(f"✅ 双路径模型已加载到GPU: {next(model.parameters()).device}")

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info("📊 双路径模型参数统计:")
    logger.info(f" 🔹 总参数量: {total_params:,}")
    logger.info(f" 🔹 可训练参数: {trainable_params:,}")

    logger.info("📋 准备训练数据集和验证数据集...")
    max_length = 384
    train_dataset = SentencePairDataset(train_data, tokenizer, max_length=max_length, is_validation=False)
    val_dataset = SentencePairDataset(val_data, tokenizer, max_length=max_length, is_validation=True)
    weighted_sampler = train_dataset.get_weighted_sampler()

    logger.info(f" 🔹 训练集大小: {len(train_dataset)}")
    logger.info(f" 🔹 验证集大小: {len(val_dataset)}")
    logger.info(f" 🔹 类别权重: {train_dataset.class_weights}")

    # Hyper-parameters (tuned by the original author for a large-memory V100).
    num_train_epochs = 100
    batch_size = 8
    gradient_accumulation = 4
    max_grad_norm = 1.0
    fp16 = True
    dataloader_num_workers = 4
    effective_batch_size = batch_size * gradient_accumulation
    initial_learning_rate = 2e-5
    warmup_ratio = 0.15

    # How often (in epochs) the confusion-matrix callbacks are asked to run.
    val_cm_interval = 10
    train_cm_interval = 20

    # Make sure both output locations exist before anything is written.
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)

    training_args = TrainingArguments(
        output_dir=checkpoint_dir,  # checkpoints are written here
        num_train_epochs=num_train_epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        gradient_accumulation_steps=gradient_accumulation,
        eval_strategy="epoch",  # evaluate after every epoch
        # load_best_model_at_end requires the save strategy to match the eval
        # strategy (TrainingArguments raises ValueError otherwise) and needs a
        # checkpoint to reload, so a checkpoint is saved after every epoch.
        save_strategy="epoch",
        logging_strategy="steps",
        logging_steps=50,
        warmup_ratio=warmup_ratio,
        weight_decay=0.01,
        learning_rate=initial_learning_rate,
        load_best_model_at_end=True,  # reload the best checkpoint when training ends
        remove_unused_columns=False,
        dataloader_pin_memory=True,
        fp16=fp16,
        dataloader_num_workers=dataloader_num_workers,
        group_by_length=True,
        report_to=[],
        adam_epsilon=1e-8,
        max_grad_norm=max_grad_norm,
        # Keep the 5 most recent checkpoints; per the Transformers docs the best
        # checkpoint is retained as well when load_best_model_at_end is enabled.
        save_total_limit=5,
        skip_memory_metrics=True,
        disable_tqdm=False,
        lr_scheduler_type="cosine",
        warmup_steps=0,
        metric_for_best_model="eval_accuracy",  # select the best model on validation accuracy
        greater_is_better=True,  # higher accuracy is better
    )

    logger.info("🎯 双路径边界分类器训练参数:")
    logger.info(" 🔹 模型类型: 纯神经网络双路径边界分类器")
    logger.info(" 🔹 边界强制权重: 2.0 (较低设置)")
    logger.info(" 🔹 边界检测方法: 纯神经网络学习")
    logger.info(f" 🔹 训练轮数: {training_args.num_train_epochs}")
    logger.info(f" 🔹 批次大小: {batch_size}")
    logger.info(f" 🔹 梯度累积: {gradient_accumulation}")
    logger.info(f" 🔹 有效批次大小: {effective_batch_size}")
    logger.info(f" 🔹 学习率: {training_args.learning_rate}")
    logger.info(f" 🔹 预热比例: {warmup_ratio}")
    logger.info(f" 🔹 序列长度: {max_length}")
    logger.info(f" 🔹 混合精度: {fp16}")
    logger.info(" 🔹 验证策略: 每个epoch评估")
    logger.info(" 🔹 保存策略: 每个epoch保存")
    logger.info(" 🔹 最佳模型选择: 验证准确率最高")
    logger.info(" 🔹 自动加载最佳模型: 启用")

    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    loss_tracker = LossTracker()
    # Validation-set confusion matrix callback.
    val_confusion_matrix_callback = ValidationConfusionMatrixCallback(
        eval_dataset=val_dataset,
        tokenizer=tokenizer,
        output_dir=checkpoint_dir,
        epochs_interval=val_cm_interval,
    )
    # Training-set confusion matrix callback.
    train_confusion_matrix_callback = TrainingConfusionMatrixCallback(
        train_dataset=train_dataset,
        tokenizer=tokenizer,
        output_dir=checkpoint_dir,
        epochs_interval=train_cm_interval,
    )

    trainer = WeightedTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        callbacks=[loss_tracker, val_confusion_matrix_callback, train_confusion_matrix_callback],
        weighted_sampler=weighted_sampler,
    )

    for banner_line in (
        "🏃♂️ 开始双路径边界分类器训练...",
        "🎯 纯神经网络双路径训练配置:",
        " ✅ 常规分类路径: 处理一般语义关系",
        " ✅ 边界分类路径: 专门处理边界语句",
        " ✅ 边界检测器: 纯神经网络自动学习边界模式",
        " ✅ 数据驱动: 完全从训练数据中学习边界特征",
        " ✅ 动态权重融合: 边界强制权重=2.0",
        " ✅ 多任务损失: final_loss + regular_loss + boundary_loss + detection_loss",
        " ✅ Focal Loss Gamma: 3.0-3.5",
        " ✅ Alpha权重: [0.1, 0.9]",
        " ✅ 学习率: 2e-5",
        " ✅ 预热比例: 15%",
        " ✅ WeightedRandomSampler",
        " ✅ 余弦退火学习率调度",
        " ✅ 验证集: 每个epoch评估",
        " ✅ 模型选择: 验证准确率最高",
        " ✅ 自动加载最佳模型",
        " ✅ 验证集混淆矩阵: 每10个epoch生成",
        " ✅ 训练集混淆矩阵: 每20个epoch生成",
    ):
        logger.info(banner_line)

    start_time = datetime.now()

    try:
        trainer.train()

        # With load_best_model_at_end the trainer now holds the best checkpoint.
        logger.info("🏆 双路径训练完成!已自动加载验证准确率最高的模型")
        logger.info(f" 🔹 最佳验证准确率: {loss_tracker.best_eval_accuracy:.4f}")
        logger.info(f" 🔹 最佳模型来自: Epoch {loss_tracker.best_epoch}")
    except RuntimeError as e:
        # Add an actionable hint for out-of-memory failures, then propagate.
        if "out of memory" in str(e).lower():
            logger.error("❌ GPU内存不足!")
            logger.error("💡 建议减小批次大小")
        raise

    end_time = datetime.now()
    training_duration = end_time - start_time
    logger.info(f"🎉 双路径训练完成! 耗时: {training_duration}")

    logger.info("📈 生成训练可视化图表...")
    plot_training_curves(loss_tracker, checkpoint_dir)

    logger.info(f"💾 保存最佳双路径模型到: {output_dir}")

    # Persist the model and tokenizer to the model output directory.
    trainer.model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)

    # Loss and validation-accuracy history collected by the loss tracker.
    training_history = {
        'train_losses': loss_tracker.train_losses,
        'train_steps': loss_tracker.train_steps,
        'eval_losses': loss_tracker.eval_losses,
        'eval_steps': loss_tracker.eval_steps,
        'eval_accuracies': loss_tracker.eval_accuracies,
        'epochs': loss_tracker.epochs,
        'best_eval_accuracy': loss_tracker.best_eval_accuracy,
        'best_epoch': loss_tracker.best_epoch,
    }
    _dump_json_file(os.path.join(checkpoint_dir, 'training_history.json'), training_history)

    # Confusion-matrix histories; ``tolist()`` makes the matrices JSON serializable.
    val_cm_history = {epoch: cm.tolist() for epoch, cm in val_confusion_matrix_callback.confusion_matrices.items()}
    _dump_json_file(os.path.join(checkpoint_dir, 'validation_confusion_matrix_history.json'), val_cm_history)

    train_cm_history = {epoch: cm.tolist() for epoch, cm in train_confusion_matrix_callback.confusion_matrices.items()}
    _dump_json_file(os.path.join(checkpoint_dir, 'training_confusion_matrix_history.json'), train_cm_history)

    # Expected confusion-matrix image names: one per interval plus the final epoch.
    val_cm_images = [
        f"validation_confusion_matrix_epoch_{i}.png"
        for i in _confusion_matrix_epochs(val_cm_interval, num_train_epochs)
    ]
    train_cm_images = [
        f"training_confusion_matrix_epoch_{i}.png"
        for i in _confusion_matrix_epochs(train_cm_interval, num_train_epochs)
    ]

    # Detailed record of the run.
    training_info = {
        "model_name": model_path,
        "model_type": "DualPathBoundaryClassifier with Pure Neural Network Learning",
        "optimization_level": "dual_path_boundary_classification_neural_only",
        "training_duration": str(training_duration),
        "num_train_samples": len(train_dataset),
        "num_val_samples": len(val_dataset),
        "validation_split": len(val_dataset) / (len(train_dataset) + len(val_dataset)),
        "label_distribution": label_distribution,
        "best_model_info": {
            "best_validation_accuracy": float(loss_tracker.best_eval_accuracy),
            "best_epoch": int(loss_tracker.best_epoch),
            "model_selection_criterion": "validation_accuracy",
            "load_best_model_at_end": True
        },
        "dual_path_config": {
            "boundary_force_weight": 2.0,
            "boundary_detection_method": "pure_neural_network",
            "has_regular_classifier": True,
            "has_boundary_classifier": True,
            "has_boundary_detector": True,
            "detection_method": "neural_network_only",
            "fusion_strategy": "dynamic_weighted",
            "learning_approach": "data_driven_boundary_detection"
        },
        "data_imbalance": {
            "class_0_count": label_distribution.get(0, 0),
            "class_1_count": label_distribution.get(1, 0),
            "class_0_ratio": label_distribution.get(0, 0) / len(train_dataset),
            "class_1_ratio": label_distribution.get(1, 0) / len(train_dataset),
            "imbalance_ratio": label_distribution.get(0, 1) / label_distribution.get(1, 1)
        },
        "optimized_focal_loss_params": {
            "alpha_weights": alpha_weights,
            "gamma": recommended_gamma,
            "formula": "FL(p_t) = -α_t * (1-p_t)^γ * log(p_t)",
            "optimization": "aggressive_minority_class_focus"
        },
        "multi_task_loss": {
            "final_loss_weight": 0.4,
            "regular_loss_weight": 0.3,
            "boundary_loss_weight": 0.2,
            "detection_loss_weight": 0.1,
            "total_formula": "0.4*final + 0.3*regular + 0.2*boundary + 0.1*detection"
        },
        "weighted_sampling": {
            "enabled": True,
            "class_weights": train_dataset.class_weights,
            "sampler_type": "WeightedRandomSampler",
            "applies_to": "training_set_only"
        },
        "validation_setup": {
            "enabled": True,
            "validation_split": "20%",
            "stratified_split": True,
            "eval_strategy": "every_epoch",
            "save_strategy": "every_epoch",
            "confusion_matrix_frequency": "every_10_epochs",
            "model_selection": "best_validation_accuracy"
        },
        "optimized_learning_strategy": {
            "initial_learning_rate": initial_learning_rate,
            "warmup_ratio": warmup_ratio,
            "lr_scheduler": "cosine",
            "improvement": "optimized_for_v100"
        },
        "gpu_optimization": {
            "gpu_name": torch.cuda.get_device_name(0),
            "gpu_memory_gb": gpu_memory,
            "optimization_target": "V100_48GB",
            "effective_batch_size": effective_batch_size,
            "sequence_length": max_length,
            "batch_size_optimization": "v100_optimized"
        },
        "training_args": {
            "num_train_epochs": training_args.num_train_epochs,
            "per_device_train_batch_size": training_args.per_device_train_batch_size,
            "per_device_eval_batch_size": training_args.per_device_eval_batch_size,
            "gradient_accumulation_steps": training_args.gradient_accumulation_steps,
            "learning_rate": training_args.learning_rate,
            "warmup_ratio": training_args.warmup_ratio,
            "weight_decay": training_args.weight_decay,
            "fp16": training_args.fp16,
            "lr_scheduler_type": training_args.lr_scheduler_type,
            "eval_strategy": training_args.eval_strategy,
            "save_strategy": training_args.save_strategy,
            "metric_for_best_model": training_args.metric_for_best_model,
            "load_best_model_at_end": training_args.load_best_model_at_end
        },
        "model_parameters": {
            "total_params": total_params,
            "trainable_params": trainable_params,
        },
        "paths": {
            "model_input_path": model_path,
            "model_output_path": output_dir,
            "checkpoint_output_path": checkpoint_dir,
            "data_path": "/root/autodl-tmp/Data"
        },
        "high_priority_optimizations": [
            "双路径架构:常规+边界专门处理",
            "边界强制权重: 2.0 (温和设置)",
            "纯神经网络边界检测 (无预定义关键词)",
            "数据驱动边界模式学习",
            "多任务损失函数优化",
            "Focal Loss Gamma increased to 3.0-3.5",
            "Alpha weights set to [0.1, 0.9] for aggressive minority class focus",
            "Learning rate optimized for V100: 2e-5",
            "Warmup ratio increased to 15%",
            "WeightedRandomSampler for balanced class sampling",
            "Cosine annealing learning rate scheduler",
            # Report the batch size actually configured above.
            f"V100 48GB optimized batch size: {batch_size}",
            "Full sequence length: 384 tokens",
            "Validation set with stratified split",
            "Best model selection based on validation accuracy",
            "Automatic best model loading at training end"
        ],
        "visualization_files": {
            "comprehensive_training_curves": "comprehensive_training_curves.png",
            "loss_comparison": "loss_comparison_curves.png",
            "validation_confusion_matrices": val_cm_images,
            "training_confusion_matrices": train_cm_images,
            "training_history": "training_history.json",
            "validation_confusion_matrix_history": "validation_confusion_matrix_history.json",
            "training_confusion_matrix_history": "training_confusion_matrix_history.json"
        },
        "training_completed": end_time.isoformat()
    }
    _dump_json_file(os.path.join(checkpoint_dir, 'training_info.json'), training_info)

    # Short summary stored next to the saved model.
    model_summary = {
        "model_selection_info": {
            "best_validation_accuracy": float(loss_tracker.best_eval_accuracy),
            "best_epoch": int(loss_tracker.best_epoch),
            "selection_criterion": "highest_validation_accuracy",
            "total_epochs_trained": training_args.num_train_epochs
        },
        "dual_path_summary": {
            "model_type": "DualPathBoundaryClassifier",
            "boundary_force_weight": 2.0,
            "boundary_detection": "pure_neural_network",
            "specialized_for": "data_driven_boundary_sentence_segmentation"
        },
        "model_config": training_info
    }
    _dump_json_file(os.path.join(output_dir, 'best_model_info.json'), model_summary)

    logger.info("📋 双路径训练信息和最佳模型选择记录已保存")

    return trainer, trainer.model, tokenizer, loss_tracker, val_confusion_matrix_callback, train_confusion_matrix_callback
|
|
|
|
|
|
|
|
def _log_lines(lines):
    """Emit each string in ``lines`` as its own INFO record, in order."""
    for line in lines:
        logger.info(line)


def _log_model_dir_listing(output_dir):
    """Log the name and size (MB) of every regular file directly inside ``output_dir``."""
    logger.info(f"📋 最佳纯神经网络双路径模型文件 ({output_dir}):")
    try:
        for name in os.listdir(output_dir):
            path = os.path.join(output_dir, name)
            if not os.path.isfile(path):
                continue
            size_mb = os.path.getsize(path) / (1024 * 1024)
            if name == 'best_model_info.json':
                logger.info(f" 🏆 {name} ({size_mb:.2f} MB) - 最佳纯神经网络双路径模型选择信息")
            else:
                logger.info(f" 📄 {name} ({size_mb:.2f} MB)")
    except OSError as e:
        # Listing is informational only; never fail the run because of it.
        logger.warning(f" ⚠️ 无法列出模型文件: {str(e)}")


def _log_training_dir_listing(checkpoint_dir, preview=3):
    """Log the contents of ``checkpoint_dir`` grouped by kind, showing at most ``preview`` items for bulky groups."""
    logger.info(f"📋 训练记录 ({checkpoint_dir}):")
    try:
        files = os.listdir(checkpoint_dir)

        # Group the entries by kind.
        val_cm_files = [f for f in files if f.startswith('validation_confusion_matrix') and f.endswith('.png')]
        train_cm_files = [f for f in files if f.startswith('training_confusion_matrix') and f.endswith('.png')]
        curve_files = [f for f in files if f.endswith('.png') and 'curve' in f]
        json_files = [f for f in files if f.endswith('.json')]
        checkpoint_dirs = [f for f in files if f.startswith('checkpoint-')]
        # Build the lookup once instead of concatenating five lists per filename.
        classified = set(val_cm_files + train_cm_files + curve_files + json_files + checkpoint_dirs)
        other_files = [f for f in files if f not in classified]

        def size_kb(name):
            return os.path.getsize(os.path.join(checkpoint_dir, name)) / 1024

        if json_files:
            logger.info(" JSON配置和历史文件:")
            for name in sorted(json_files):
                file_size = size_kb(name)
                if name == 'training_history.json':
                    logger.info(f" 📈 {name} ({file_size:.1f} KB) - 完整纯神经网络双路径训练历史")
                elif name == 'training_info.json':
                    logger.info(f" 📄 {name} ({file_size:.1f} KB) - 纯神经网络双路径模型详细信息")
                else:
                    logger.info(f" 📄 {name} ({file_size:.1f} KB)")

        if curve_files:
            logger.info(" 训练曲线图表:")
            for name in sorted(curve_files):
                file_size = size_kb(name)
                if 'comprehensive' in name:
                    logger.info(f" 📊 {name} ({file_size:.1f} KB) - 综合纯神经网络双路径训练曲线")
                else:
                    logger.info(f" 📊 {name} ({file_size:.1f} KB)")

        if val_cm_files:
            logger.info(" 验证集混淆矩阵 (每10个epoch):")
            for name in sorted(val_cm_files)[:preview]:
                logger.info(f" 📊 {name} ({size_kb(name):.1f} KB)")
            if len(val_cm_files) > preview:
                logger.info(f" ... 以及其他 {len(val_cm_files) - preview} 个验证集混淆矩阵文件")

        if train_cm_files:
            logger.info(" 训练集混淆矩阵 (每20个epoch):")
            for name in sorted(train_cm_files)[:preview]:
                logger.info(f" 📊 {name} ({size_kb(name):.1f} KB)")
            if len(train_cm_files) > preview:
                logger.info(f" ... 以及其他 {len(train_cm_files) - preview} 个训练集混淆矩阵文件")

        if checkpoint_dirs:
            logger.info(" 训练检查点:")
            for dir_name in sorted(checkpoint_dirs)[:preview]:
                logger.info(f" 📁 {dir_name}/")
            if len(checkpoint_dirs) > preview:
                logger.info(f" ... 以及其他 {len(checkpoint_dirs) - preview} 个检查点目录")

        if other_files:
            logger.info(" 其他文件:")
            for name in sorted(other_files):
                path = os.path.join(checkpoint_dir, name)
                if os.path.isfile(path):
                    size_mb = os.path.getsize(path) / (1024 * 1024)
                    logger.info(f" 📄 {name} ({size_mb:.2f} MB)")
    except OSError as e:
        # Listing is informational only; never fail the run because of it.
        logger.warning(f" ⚠️ 无法列出训练记录: {str(e)}")


def main():
    """Run the whole pipeline: log the configuration, load and split the data, train, and report outputs.

    Paths are hard-coded below. If ``load_and_split_data`` returns ``None`` for
    either split, an error is logged and the function returns without training.
    Any exception raised during training or reporting is logged with its
    traceback and re-raised.
    """
    logger.info("=" * 120)
    logger.info("🚀 Chinese-RoBERTa-WWM-Ext 纯神经网络双路径边界分类器训练")
    logger.info("=" * 120)

    # Path configuration.
    train_file = "/root//autodl-tmp/model_train-Data/Data/train_dataset.json"
    model_path = "/root/autodl-tmp/Robert-wwm-ext"
    output_dir = "/root/autodl-tmp/model_chinese-roberta-wwm-ext/model_train"
    checkpoint_dir = "/root/autodl-tmp/model_chinese-roberta-wwm-ext/ouput_result"

    # Make sure every output directory exists.
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)
    _log_lines([
        "📁 确保输出目录存在:",
        f" 🔹 最佳模型输出: {output_dir}",
        f" 🔹 训练记录: {checkpoint_dir}",
    ])

    # Report the state of the third-party reporting integrations.
    _log_lines([
        "🚫 确认第三方报告工具状态:",
        f" 🔹 WANDB_DISABLED: {os.environ.get('WANDB_DISABLED', 'not set')}",
        f" 🔹 COMET_MODE: {os.environ.get('COMET_MODE', 'not set')}",
        f" 🔹 NEPTUNE_MODE: {os.environ.get('NEPTUNE_MODE', 'not set')}",
        " ✅ 所有第三方报告工具已禁用",
    ])

    _log_lines([
        "\n📋 纯神经网络双路径边界分类器配置:",
        f" 🔹 训练数据: {train_file}",
        f" 🔹 基础模型: {model_path}",
        " 🔹 模型类型: DualPathBoundaryClassifier (纯神经网络)",
        " 🔹 验证集: 20%分层划分",
        " 🔹 模型选择标准: 验证集准确率最高",
        " 🔹 自动加载最佳模型: 启用",
        " 🔹 目标: 处理严重数据不平衡问题 + 数据驱动边界句子识别",
        " 🔹 纯神经网络双路径核心特性:",
        " • 常规分类路径: 处理一般语义关系",
        " • 边界分类路径: 专门处理边界语句",
        " • 边界检测器: 纯神经网络自动学习边界特征",
        " • 边界强制权重: 2.0 (温和设置)",
        " • 多任务损失函数",
        " • 数据驱动: 无预定义关键词,完全从数据学习",
        " 🔹 继承的核心优化:",
        " • Focal Loss Gamma: 3.0+ (增强难样本聚焦)",
        " • Alpha权重: [0.1, 0.9] (激进的少数类关注)",
        " • 学习率: 2e-5 (V100优化)",
        # The trainer in this file uses a per-device batch size of 8.
        " • 批次大小: 8 (V100大显存优化)",
        " • 序列长度: 384 (完整长度)",
        " • WeightedRandomSampler (平衡采样)",
        " • 每epoch验证和保存",
        " • 验证集混淆矩阵每10个epoch生成",
        " 🔹 训练轮数: 100 epochs",
        f" 🔹 最佳模型输出: {output_dir}",
        f" 🔹 训练记录: {checkpoint_dir}",
    ])

    # Load the data and split it into training and validation sets.
    train_data, val_data = load_and_split_data(train_file, validation_split=0.2, random_state=42)
    if train_data is None or val_data is None:
        logger.error("❌ 无法加载和划分数据,程序退出")
        return

    try:
        # Train; only the loss tracker is needed for the report below.
        _, _, _, loss_tracker, _, _ = train_roberta_model(
            train_data, val_data,
            model_path=model_path,
            output_dir=output_dir,
            checkpoint_dir=checkpoint_dir
        )

        logger.info("=" * 120)
        logger.info("🎉 纯神经网络双路径边界分类器训练完成!")
        logger.info("=" * 120)
        _log_lines([
            "🏆 最佳纯神经网络双路径模型信息:",
            f" 🔹 验证准确率: {loss_tracker.best_eval_accuracy:.4f}",
            f" 🔹 来自Epoch: {loss_tracker.best_epoch}",
            " 🔹 选择标准: 验证集准确率最高",
            " 🔹 已自动加载并保存最佳模型",
            " 🔹 边界强制权重: 2.0 (温和设置)",
            " 🔹 边界检测方法: 纯神经网络学习",
        ])

        _log_lines([
            "\n📁 文件输出位置:",
            f" 🔹 最佳双路径模型: {output_dir}",
            f" 🔹 训练记录和图表: {checkpoint_dir}",
        ])

        _log_lines([
            "📄 生成的文件:",
            " 最佳纯神经网络双路径模型文件 (model_train目录):",
            " • pytorch_model.bin - 验证性能最佳的纯神经网络双路径模型权重",
            " • config.json - 纯神经网络双路径模型配置",
            " • tokenizer配置文件",
            " • best_model_info.json - 最佳双路径模型选择信息",
        ])

        _log_lines([
            " 训练记录 (ouput_result目录):",
            " • training_info.json - 详细纯神经网络双路径训练信息",
            " • training_history.json - 完整训练历史(包含验证准确率)",
            " • validation_confusion_matrix_history.json - 验证集混淆矩阵历史",
            " • training_confusion_matrix_history.json - 训练集混淆矩阵历史",
            " • comprehensive_training_curves.png - 综合训练曲线(含验证准确率)",
            " • loss_comparison_curves.png - 训练vs验证损失对比",
            " • validation_confusion_matrix_epoch_X.png - 验证集混淆矩阵(每10个epoch)",
            " • training_confusion_matrix_epoch_X.png - 训练集混淆矩阵(每20个epoch)",
            " • checkpoint-* - 所有训练检查点",
        ])

        _log_lines([
            "🔥 纯神经网络双路径边界分类器特性:",
            " ✅ Chinese-RoBERTa-WWM-Ext 基础模型",
            " ✅ 双路径架构: 常规路径 + 边界路径",
            " ✅ 边界检测器: 纯神经网络学习边界模式",
            " ✅ 数据驱动: 无预定义关键词,完全从数据学习",
            " ✅ 边界强制权重: 2.0 (温和设置,避免过度强制)",
            " ✅ 多任务损失函数: 4种损失加权组合",
            " ✅ 动态权重融合: 根据边界置信度自适应调节",
            " ✅ 激进的Focal Loss参数 (Gamma=3.0+, Alpha=[0.1,0.9])",
            " ✅ V100优化学习率: 2e-5",
            # Per-device batch size 8 with 4 accumulation steps gives 32.
            " ✅ 大批次训练: 8 (有效批次: 32)",
            " ✅ 完整序列长度: 384 tokens",
            " ✅ WeightedRandomSampler 平衡采样",
            " ✅ 余弦退火学习率调度",
            " ✅ 验证集分层划分 (20%)",
            " ✅ 每个epoch验证评估和保存",
            " ✅ 自动选择验证准确率最高的模型",
            " ✅ 验证集混淆矩阵每10个epoch",
            " ✅ 训练集混淆矩阵每20个epoch",
            " ✅ 100 epochs长时间训练",
            " ✅ 完整可视化监控",
        ])

        _log_lines([
            "🎯 纯神经网络双路径针对数据不平衡和边界识别的专项优化:",
            " ⚡ 常规路径: 专注一般语义关系学习",
            " ⚡ 边界路径: 专门处理边界语句模式",
            " ⚡ 边界检测: 神经网络自动学习边界特征",
            " ⚡ 数据驱动: 从训练样本中发现真正的边界模式",
            " ⚡ 自适应学习: 模型自主发现边界规律,无人工先验",
            " ⚡ 强制分段: 边界句子自动偏向分段决策",
            " ⚡ 少数类样本权重提升9倍",
            " ⚡ 难分类样本聚焦增强50%",
            " ⚡ V100大显存充分利用",
            " ⚡ 类别平衡采样确保训练公平性",
            " ⚡ 验证集实时监控防止过拟合",
            " ⚡ 自动选择泛化能力最强的模型",
            " ⚡ 预期边界句子识别准确率提升25-40%",
            " ⚡ 预期少数类F1分数提升20-35%",
        ])

        # Show what was actually written to both output directories.
        logger.info("\n📂 文件保存详情:")
        _log_model_dir_listing(output_dir)
        _log_training_dir_listing(checkpoint_dir)

        _log_lines([
            "\n🎯 纯神经网络双路径训练完成,最佳模型已自动选择并保存!",
            "📊 建议查看:",
            " • best_model_info.json - 纯神经网络双路径最佳模型选择详情",
            " • comprehensive_training_curves.png - 验证准确率变化趋势",
            " • 验证集混淆矩阵的演进过程",
            " • 训练损失vs验证损失的收敛情况",
            " • 最佳模型对应epoch的验证集性能",
            " • 纯神经网络双路径模型的边界检测效果",
        ])

        _log_lines([
            "\n🏆 最佳纯神经网络双路径模型总结:",
            f" • 验证准确率: {loss_tracker.best_eval_accuracy:.4f}",
            f" • 最佳Epoch: {loss_tracker.best_epoch}",
            " • 模型类型: DualPathBoundaryClassifier (纯神经网络)",
            " • 边界强制权重: 2.0 (温和设置)",
            " • 边界检测方法: 纯神经网络学习",
            f" • 模型保存位置: {output_dir}",
            " • 可直接用于推理和部署",
            " • 专门优化数据驱动的边界句子识别和强制分段",
        ])

        _log_lines([
            "\n🎖️ 纯神经网络双路径模型优势:",
            " • 常规路径保持泛化能力",
            " • 边界路径专门处理边界句子",
            " • 纯神经网络自动发现边界模式",
            " • 数据驱动,无人工先验偏见",
            " • 温和的强制权重避免过度干预",
            " • 多任务学习提升整体性能",
            " • 完全自适应的边界识别能力",
        ])

    except Exception as e:
        # logger.exception records the message together with the traceback.
        logger.exception(f"❌ 纯神经网络双路径边界分类器训练过程中出现错误: {str(e)}")
        raise
|
|
|
|
|
|
|
|
# Script entry point: run the training pipeline only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()