You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

1169 lines
48 KiB

import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.model_selection import train_test_split
from transformers import (
BertTokenizer,
BertForSequenceClassification,
BertModel,
BertConfig,
TrainingArguments,
Trainer,
DataCollatorWithPadding,
TrainerCallback
)
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import logging
import os
from datetime import datetime
import math
from collections import defaultdict, Counter
# Turn off wandb and the other third-party experiment trackers that Trainer may
# auto-detect, so nothing is reported to external services.
os.environ.update({
    "WANDB_DISABLED": "true",
    "COMET_MODE": "disabled",
    "NEPTUNE_MODE": "disabled",
})

# Module-wide logging at INFO level.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Matplotlib font configuration. NOTE(review): the original comment called this a
# "Chinese font" setting, but DejaVu Sans has no CJK glyphs; the chart text in the
# plotting code shown here is English.
plt.rcParams.update({
    'font.sans-serif': ['DejaVu Sans'],
    'axes.unicode_minus': False,
})
def check_gpu_availability():
    """Verify that a CUDA GPU is available, log its properties and tune cuDNN.

    Returns:
        tuple: ``(True, memory_gb)`` where ``memory_gb`` is the total memory of
        CUDA device 0 in GiB.

    Raises:
        RuntimeError: if ``torch.cuda.is_available()`` reports no GPU.
    """
    if torch.cuda.is_available():
        device_count = torch.cuda.device_count()
        device_name = torch.cuda.get_device_name(0)
        memory_gb = torch.cuda.get_device_properties(0).total_memory / 1024 ** 3
        for message in (
            "✅ GPU检查通过!",
            f" 🔹 可用GPU数量: {device_count}",
            f" 🔹 GPU型号: {device_name}",
            f" 🔹 GPU内存: {memory_gb:.1f} GB",
        ):
            logger.info(message)
        # Free cached blocks and let cuDNN benchmark kernels for fixed-size inputs.
        torch.cuda.empty_cache()
        torch.backends.cudnn.benchmark = True
        return True, memory_gb
    raise RuntimeError("❌ GPU不可用!此脚本需要GPU支持。")
class LossTracker(TrainerCallback):
    """Trainer callback that records training and evaluation losses.

    Each ``loss`` / ``eval_loss`` value seen in a Trainer log event is appended to
    ``train_losses`` / ``eval_losses``, together with ``state.global_step`` in
    ``train_steps`` / ``eval_steps``. ``current_epoch`` mirrors ``state.epoch`` at
    the end of each epoch.
    """

    def __init__(self):
        self.train_losses, self.train_steps = [], []
        self.eval_losses, self.eval_steps = [], []
        self.current_epoch = 0

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Store the loss entries of a log event together with the current global step."""
        if not logs:
            return
        routes = (
            ('loss', self.train_losses, self.train_steps),
            ('eval_loss', self.eval_losses, self.eval_steps),
        )
        for key, value_list, step_list in routes:
            if key in logs:
                value_list.append(logs[key])
                step_list.append(state.global_step)

    def on_epoch_end(self, args, state, control, **kwargs):
        """Remember the (possibly fractional) epoch reported by the Trainer."""
        self.current_epoch = state.epoch
class ValidationConfusionMatrixCallback(TrainerCallback):
    """Trainer callback that saves a validation-set confusion matrix periodically.

    At the end of every ``epochs_interval``-th epoch, and at the final epoch, the
    model is run over the whole ``eval_dataset``. The resulting 2x2 matrix is kept
    in ``self.confusion_matrices[epoch]`` and rendered to a PNG in ``output_dir``.

    Args:
        eval_dataset: indexable dataset whose items hold ``input_ids`` and
            ``attention_mask`` (1-D tensors of equal length, so they can be stacked)
            and a scalar ``labels`` tensor.
        tokenizer: stored for API compatibility; not used by this callback.
        output_dir: directory the heatmap PNGs are written to.
        epochs_interval: produce a matrix every this many epochs.
        batch_size: number of samples per inference forward pass.
    """

    def __init__(self, eval_dataset, tokenizer, output_dir, epochs_interval=10, batch_size=32):
        self.eval_dataset = eval_dataset
        self.tokenizer = tokenizer
        self.output_dir = output_dir
        self.epochs_interval = epochs_interval
        self.batch_size = max(1, int(batch_size))
        self.confusion_matrices = {}

    def on_epoch_end(self, args, state, control, model=None, **kwargs):
        # state.epoch is a float; round instead of truncating so that a value such
        # as 9.9999 is treated as epoch 10.
        current_epoch = int(round(state.epoch))
        if current_epoch % self.epochs_interval != 0 and current_epoch != args.num_train_epochs:
            return
        if model is None:
            logger.warning("ValidationConfusionMatrixCallback: no model passed, skipping")
            return
        logger.info(f"📊 Generating validation confusion matrix for epoch {current_epoch}...")
        true_labels, predictions = self._predict(model)
        # Pin labels=[0, 1] so the matrix is always 2x2 and matches the tick labels,
        # even when a class is missing from the predictions or the data.
        cm = confusion_matrix(true_labels, predictions, labels=[0, 1])
        self.confusion_matrices[current_epoch] = cm
        self.save_confusion_matrix(cm, current_epoch)

    def _predict(self, model):
        """Run batched inference over the eval set; return ``(true_labels, predictions)``."""
        was_training = model.training
        model.eval()
        device = next(model.parameters()).device
        n_items = len(self.eval_dataset)
        predictions, true_labels = [], []
        try:
            with torch.no_grad():
                for start in range(0, n_items, self.batch_size):
                    items = [self.eval_dataset[i]
                             for i in range(start, min(start + self.batch_size, n_items))]
                    input_ids = torch.stack([item['input_ids'] for item in items]).to(device)
                    attention_mask = torch.stack([item['attention_mask'] for item in items]).to(device)
                    logits = model(input_ids=input_ids, attention_mask=attention_mask)['logits']
                    predictions.extend(torch.argmax(logits, dim=-1).cpu().tolist())
                    true_labels.extend(int(item['labels']) for item in items)
        finally:
            # Restore the mode the Trainer had the model in, even if inference fails.
            model.train(was_training)
        return true_labels, predictions

    def save_confusion_matrix(self, cm, epoch):
        """Render ``cm`` as a heatmap with its accuracy and save it as a PNG."""
        fig, ax = plt.subplots(figsize=(8, 6))
        try:
            sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax,
                        xticklabels=['Same Paragraph (0)', 'Different Paragraph (1)'],
                        yticklabels=['Same Paragraph (0)', 'Different Paragraph (1)'])
            ax.set_title(f'Validation Confusion Matrix - Epoch {epoch} (Resumed)')
            ax.set_xlabel('Predicted Label')
            ax.set_ylabel('True Label')
            total = np.sum(cm)
            # Guard against an empty matrix (no samples) instead of dividing by zero.
            accuracy = np.trace(cm) / total if total else 0.0
            ax.text(0.5, -0.15, f'Validation Accuracy: {accuracy:.4f}',
                    ha='center', transform=ax.transAxes)
            fig.tight_layout()
            save_path = os.path.join(self.output_dir, f'validation_confusion_matrix_epoch_{epoch}_resumed.png')
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
        finally:
            plt.close(fig)
        logger.info(f" 💾 Validation confusion matrix saved: {save_path}")
class TrainingConfusionMatrixCallback(TrainerCallback):
    """Trainer callback that saves a confusion matrix on a training-set sample.

    At the end of every ``epochs_interval``-th epoch, and at the final epoch, the
    model is evaluated on a random subset of at most ``subset_size`` training items.
    The subset keeps the run time bounded. The 2x2 matrix is kept in
    ``self.confusion_matrices[epoch]`` and rendered to a PNG in ``output_dir``.

    Args:
        train_dataset: indexable dataset whose items hold ``input_ids`` and
            ``attention_mask`` (1-D tensors of equal length) and a scalar ``labels``.
        tokenizer: stored for API compatibility; not used by this callback.
        output_dir: directory the heatmap PNGs are written to.
        epochs_interval: produce a matrix every this many epochs.
        subset_size: maximum number of training items evaluated per matrix.
        batch_size: number of samples per inference forward pass.
        seed: base seed of the private RNG used to draw the subset.
    """

    def __init__(self, train_dataset, tokenizer, output_dir, epochs_interval=20,
                 subset_size=1000, batch_size=32, seed=42):
        self.train_dataset = train_dataset
        self.tokenizer = tokenizer
        self.output_dir = output_dir
        self.epochs_interval = epochs_interval
        self.subset_size = subset_size
        self.batch_size = max(1, int(batch_size))
        self.seed = seed
        self.confusion_matrices = {}

    def on_epoch_end(self, args, state, control, model=None, **kwargs):
        # state.epoch is a float; round instead of truncating (9.9999 -> 10).
        current_epoch = int(round(state.epoch))
        if current_epoch % self.epochs_interval != 0 and current_epoch != args.num_train_epochs:
            return
        if model is None:
            logger.warning("TrainingConfusionMatrixCallback: no model passed, skipping")
            return
        logger.info(f"📊 Generating training confusion matrix for epoch {current_epoch}...")
        n_total = len(self.train_dataset)
        sample_size = min(self.subset_size, n_total)
        # A private generator keeps the subset reproducible per epoch and does not
        # consume or perturb the global NumPy RNG state used by the rest of training.
        rng = np.random.default_rng(self.seed + current_epoch)
        indices = rng.choice(n_total, sample_size, replace=False)
        true_labels, predictions = self._predict(model, indices)
        # Pin labels=[0, 1] so the matrix is always 2x2 and matches the tick labels.
        cm = confusion_matrix(true_labels, predictions, labels=[0, 1])
        self.confusion_matrices[current_epoch] = cm
        self.save_confusion_matrix(cm, current_epoch)

    def _predict(self, model, indices):
        """Run batched inference on ``train_dataset[indices]``; return ``(true_labels, predictions)``."""
        was_training = model.training
        model.eval()
        device = next(model.parameters()).device
        predictions, true_labels = [], []
        try:
            with torch.no_grad():
                for start in range(0, len(indices), self.batch_size):
                    items = [self.train_dataset[i] for i in indices[start:start + self.batch_size]]
                    input_ids = torch.stack([item['input_ids'] for item in items]).to(device)
                    attention_mask = torch.stack([item['attention_mask'] for item in items]).to(device)
                    logits = model(input_ids=input_ids, attention_mask=attention_mask)['logits']
                    predictions.extend(torch.argmax(logits, dim=-1).cpu().tolist())
                    true_labels.extend(int(item['labels']) for item in items)
        finally:
            # Restore the mode the Trainer had the model in, even if inference fails.
            model.train(was_training)
        return true_labels, predictions

    def save_confusion_matrix(self, cm, epoch):
        """Render ``cm`` as a heatmap with its accuracy and save it as a PNG."""
        fig, ax = plt.subplots(figsize=(8, 6))
        try:
            sns.heatmap(cm, annot=True, fmt='d', cmap='Greens', ax=ax,
                        xticklabels=['Same Paragraph (0)', 'Different Paragraph (1)'],
                        yticklabels=['Same Paragraph (0)', 'Different Paragraph (1)'])
            ax.set_title(f'Training Confusion Matrix - Epoch {epoch} (Resumed)')
            ax.set_xlabel('Predicted Label')
            ax.set_ylabel('True Label')
            total = np.sum(cm)
            # Guard against an empty matrix (no samples) instead of dividing by zero.
            accuracy = np.trace(cm) / total if total else 0.0
            ax.text(0.5, -0.15, f'Training Accuracy: {accuracy:.4f}',
                    ha='center', transform=ax.transAxes)
            fig.tight_layout()
            save_path = os.path.join(self.output_dir, f'training_confusion_matrix_epoch_{epoch}_resumed.png')
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
        finally:
            plt.close(fig)
        logger.info(f" 💾 Training confusion matrix saved: {save_path}")
def _plot_loss_comparison(train_steps, train_losses, eval_steps, eval_losses, output_dir):
    """Overlay training and validation loss on one chart, each on its own step axis.

    Writes ``loss_comparison_curves_resumed.png`` into ``output_dir``.
    """
    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        # Each series keeps its own x-values. Truncating both to a common length
        # would pair early training-loss logs with late evaluations.
        ax.plot(train_steps, train_losses,
                'b-', label='Training Loss (Resumed)', linewidth=2, alpha=0.8)
        ax.plot(eval_steps, eval_losses,
                'r-', label='Validation Loss (Resumed)', linewidth=2, alpha=0.8)
        ax.set_title('Training vs Validation Loss Comparison (Resumed from Checkpoint)',
                     fontsize=16, fontweight='bold')
        ax.set_xlabel('Training Steps', fontsize=12)
        ax.set_ylabel('Loss Value', fontsize=12)
        ax.legend(fontsize=12)
        ax.grid(True, alpha=0.3)
        if len(eval_losses) > 20:
            # Compare the last 10 evaluations with the training losses logged over
            # the same step window; fall back to the last 10 training logs.
            window_start = eval_steps[-10]
            recent_train_values = [loss for step, loss in zip(train_steps, train_losses)
                                   if step >= window_start] or train_losses[-10:]
            recent_train = np.mean(recent_train_values)
            recent_eval = np.mean(eval_losses[-10:])
            if recent_eval > recent_train * 1.2:
                ax.text(0.7, 0.9, ' Possible Overfitting', transform=ax.transAxes,
                        bbox=dict(boxstyle="round,pad=0.3", facecolor="yellow", alpha=0.7))
        fig.tight_layout()
        compare_path = os.path.join(output_dir, 'loss_comparison_curves_resumed.png')
        fig.savefig(compare_path, dpi=300, bbox_inches='tight')
    finally:
        plt.close(fig)
    logger.info(f"📈 Training comparison curves saved: {compare_path}")


def plot_training_curves(loss_tracker, output_dir):
    """Plot the losses recorded by ``loss_tracker`` and save them as PNG files.

    Writes ``training_curves_resumed.png``, with training loss (plus a linear trend
    line when more than 10 points exist) on top and validation loss below, for
    whichever series are present. When both series exist it also writes
    ``loss_comparison_curves_resumed.png``, which overlays them.

    Each figure is created and closed explicitly. Previously the comparison figure
    became the "current" figure, so it was saved under the curves filename and the
    subplot figure was never saved or closed.

    Args:
        loss_tracker: object exposing ``train_steps``, ``train_losses``,
            ``eval_steps`` and ``eval_losses`` lists.
        output_dir: directory the PNG files are written to.
    """
    train_steps, train_losses = loss_tracker.train_steps, loss_tracker.train_losses
    eval_steps, eval_losses = loss_tracker.eval_steps, loss_tracker.eval_losses

    fig = plt.figure(figsize=(12, 8))
    try:
        if train_losses:
            ax = fig.add_subplot(2, 1, 1)
            ax.plot(train_steps, train_losses,
                    'b-', label='Training Loss (Resumed)', linewidth=2, alpha=0.8)
            ax.set_title('Training Loss Curve (Resumed from Checkpoint)', fontsize=14, fontweight='bold')
            ax.set_xlabel('Training Steps')
            ax.set_ylabel('Loss Value')
            ax.grid(True, alpha=0.3)
            if len(train_losses) > 10:
                trend = np.poly1d(np.polyfit(train_steps, train_losses, 1))
                ax.plot(train_steps, trend(train_steps), 'r--', alpha=0.6, label='Trend Line')
            ax.legend()
        if eval_losses:
            ax = fig.add_subplot(2, 1, 2)
            ax.plot(eval_steps, eval_losses,
                    'g-', label='Validation Loss (Resumed)', linewidth=2, alpha=0.8)
            ax.set_title('Validation Loss Curve (Resumed from Checkpoint)', fontsize=14, fontweight='bold')
            ax.set_xlabel('Training Steps')
            ax.set_ylabel('Loss Value')
            ax.legend()
            ax.grid(True, alpha=0.3)
        fig.tight_layout()
        curves_path = os.path.join(output_dir, 'training_curves_resumed.png')
        fig.savefig(curves_path, dpi=300, bbox_inches='tight')
    finally:
        plt.close(fig)
    logger.info(f"📈 Training curves saved: {curves_path}")

    if train_losses and eval_losses:
        _plot_loss_comparison(train_steps, train_losses, eval_steps, eval_losses, output_dir)
class SentencePairDataset(Dataset):
"""句子对数据集类(支持加权采样)"""
def __init__(self, data, tokenizer, max_length=512, is_validation=False):
self.data = data
self.tokenizer = tokenizer
self.max_length = max_length
self.is_validation = is_validation
self.valid_data = [item for item in data if item['label'] in [0, 1]]
dataset_type = "验证" if is_validation else "训练"
logger.info(f"原始{dataset_type}数据: {len(data)} 条,有效数据: {len(self.valid_data)}")
self.sentence1_list = [item['sentence1'] for item in self.valid_data]
self.sentence2_list = [item['sentence2'] for item in self.valid_data]
self.labels = [item['label'] for item in self.valid_data]
# 只为训练集计算权重和采样器
if not is_validation:
self.class_counts = Counter(self.labels)
self.class_weights = self._compute_class_weights()
self.sample_weights = self._compute_sample_weights()
def _compute_class_weights(self):
"""计算类别权重"""
total_samples = len(self.labels)
class_weights = {}
for label in [0, 1]:
count = self.class_counts[label]
class_weights[label] = total_samples / (2 * count)
return class_weights
def _compute_sample_weights(self):
"""计算每个样本的权重"""
sample_weights = []
for label in self.labels:
sample_weights.append(self.class_weights[label])
return torch.tensor(sample_weights, dtype=torch.float)
def get_weighted_sampler(self):
"""返回加权随机采样器(仅训练集)"""
if self.is_validation:
raise ValueError("验证集不需要加权采样器")
return WeightedRandomSampler(
weights=self.sample_weights,
num_samples=len(self.sample_weights),
replacement=True
)
def __len__(self):
return len(self.valid_data)
def __getitem__(self, idx):
sentence1 = str(self.sentence1_list[idx])
sentence2 = str(self.sentence2_list[idx])
label = self.labels[idx]
encoding = self.tokenizer(
sentence1,
sentence2,
truncation=True,
padding='max_length',
max_length=self.max_length,
return_tensors='pt'
)
return {
'input_ids': encoding['input_ids'].flatten(),
'attention_mask': encoding['attention_mask'].flatten(),
'labels': torch.tensor(label, dtype=torch.long)
}
def load_and_split_data(train_file, validation_split=0.2, random_state=42):
    """Load a JSON list of sentence-pair records and split it into train/validation.

    Records whose ``label`` is not 0 or 1, or that have no ``label``, are dropped.
    The split is stratified by label.

    Args:
        train_file: path to a UTF-8 JSON file containing a list of dict records.
        validation_split: fraction of the records placed in the validation split.
        random_state: seed passed to ``train_test_split``.

    Returns:
        ``(train_data, val_data)`` lists, or ``(None, None)`` when loading or
        splitting fails. The error is logged together with its traceback.
    """
    try:
        with open(train_file, 'r', encoding='utf-8') as f:
            all_data = json.load(f)
        logger.info(f"成功加载原始数据: {len(all_data)} 条记录")
        # Keep only valid labels; .get skips a record without 'label' instead of
        # aborting the whole load with KeyError.
        valid_data = [item for item in all_data if item.get('label') in (0, 1)]
        logger.info(f"有效数据: {len(valid_data)} 条记录")
        # Stratified split by label.
        labels = [item['label'] for item in valid_data]
        train_data, val_data = train_test_split(
            valid_data,
            test_size=validation_split,
            random_state=random_state,
            stratify=labels
        )
        logger.info("数据划分完成:")
        logger.info(f" 🔹 训练集: {len(train_data)}")
        logger.info(f" 🔹 验证集: {len(val_data)}")
        logger.info(f" 🔹 验证集比例: {validation_split*100:.1f}%")
        # Report the label distribution of each split.
        train_counts = Counter(item['label'] for item in train_data)
        val_counts = Counter(item['label'] for item in val_data)
        logger.info(f"训练集分布: 标签0={train_counts[0]}({train_counts[0]/len(train_data)*100:.1f}%), 标签1={train_counts[1]}({train_counts[1]/len(train_data)*100:.1f}%)")
        logger.info(f"验证集分布: 标签0={val_counts[0]}({val_counts[0]/len(val_data)*100:.1f}%), 标签1={val_counts[1]}({val_counts[1]/len(val_data)*100:.1f}%)")
        return train_data, val_data
    except Exception as e:
        # Callers expect (None, None) on failure. logger.exception also records the
        # traceback, which logger.error(str(e)) discarded.
        logger.exception(f"加载和划分数据失败: {str(e)}")
        return None, None
def analyze_data_distribution(data):
    """Analyse the label distribution of ``data`` and derive Focal Loss parameters.

    Args:
        data: list of dict records with a ``'label'`` key; only labels 0 and 1 count.

    Returns:
        tuple ``(label_counts, alpha_weights, recommended_gamma)``:
        - label_counts: ``{label: count}`` for the labels that occur.
        - alpha_weights: ``[alpha_for_label_0, alpha_for_label_1]``. When the
          imbalance ratio exceeds 5, the minority class gets 0.9 and the majority
          class 0.1. Otherwise each class gets ``1 - its proportion``.
        - recommended_gamma: 3.5 if imbalance > 6, 3.0 if > 4, else 2.5.

    Raises:
        ValueError: if ``data`` contains no record labelled 0 or 1.
    """
    valid_data = [item for item in data if item['label'] in [0, 1]]
    # Counter keeps first-seen key order, like the hand-built dict it replaces.
    label_counts = dict(Counter(item['label'] for item in valid_data))
    total_samples = len(valid_data)
    if total_samples == 0:
        raise ValueError("analyze_data_distribution: data contains no samples labelled 0 or 1")
    logger.info("=== 训练数据分布分析 ===")
    logger.info(f"总有效记录数: {total_samples}")
    class_proportions = {}
    for label in (0, 1):
        count = label_counts.get(label, 0)
        proportion = count / total_samples
        class_proportions[label] = proportion
        label_name = "同段落" if label == 0 else "不同段落"
        logger.info(f"标签 {label} ({label_name}): {count} 条 ({proportion * 100:.2f}%)")
    minority_label = min(class_proportions, key=class_proportions.get)
    minority_ratio = class_proportions[minority_label]
    majority_ratio = max(class_proportions.values())
    # A class that never occurs makes the ratio unbounded; inf selects the most
    # aggressive settings below instead of raising ZeroDivisionError.
    imbalance_ratio = majority_ratio / minority_ratio if minority_ratio > 0 else float('inf')
    logger.info(f"\n📊 数据不平衡分析:")
    logger.info(f" 🔹 少数类比例: {minority_ratio * 100:.2f}%")
    logger.info(f" 🔹 不平衡比率: {imbalance_ratio:.2f}:1")
    if imbalance_ratio > 5:
        # Give the large weight to whichever label is the minority. The previous
        # hard-coded [0.1, 0.9] was only correct when label 1 was the minority.
        alpha_weights = [0.9 if label == minority_label else 0.1 for label in (0, 1)]
        logger.info(" 🎯 使用激进的alpha权重设置")
    else:
        alpha_weights = [1.0 - class_proportions[0], 1.0 - class_proportions[1]]
    if imbalance_ratio > 6:
        recommended_gamma = 3.5
        logger.info(" 严重不平衡 - 使用Gamma=3.5强化聚焦")
    elif imbalance_ratio > 4:
        recommended_gamma = 3.0
        logger.info(" 中度偏严重不平衡 - 使用Gamma=3.0")
    else:
        recommended_gamma = 2.5
    logger.info(f"\n🎯 优化的Focal Loss参数设置:")
    logger.info(f" 🔹 Alpha权重: [标签0={alpha_weights[0]:.3f}, 标签1={alpha_weights[1]:.3f}] (少数类=标签{minority_label})")
    logger.info(f" 🔹 优化Gamma: {recommended_gamma} (增强难样本聚焦)")
    logger.info(f" 🔹 公式: FL(p_t) = -α_t * (1-p_t)^γ * log(p_t)")
    logger.info(f" 🔹 加权采样: 启用WeightedRandomSampler")
    return label_counts, alpha_weights, recommended_gamma
def compute_metrics(eval_pred):
    """Compute evaluation metrics for the binary sentence-pair task.

    Args:
        eval_pred: ``(predictions, labels)`` as passed by ``Trainer``.
            ``predictions`` may be any of:
            - a tuple/list whose first element is the logits (when the model
              returns extra outputs);
            - a ``(N, 2)`` logit array;
            - a 3-D array (the last position along axis 1 is used);
            - flattened logits of length ``2N``;
            - ``N`` already-decoded class ids.

    Returns:
        dict with ``accuracy`` plus ``precision_class1``, ``recall_class1`` and
        ``f1_class1`` (label 1 as the positive class), and ``f1_macro``. On
        imbalanced data accuracy alone hides minority-class failures.

    Raises:
        ValueError: if the number of predictions does not match the number of labels.
    """
    predictions, labels = eval_pred
    # The model returns a dict; Trainer turns non-loss outputs into a tuple whose
    # first element is the logits.
    if isinstance(predictions, (list, tuple)):
        predictions = predictions[0]
    predictions = np.asarray(predictions)
    labels = np.asarray(labels).reshape(-1)

    if predictions.ndim == 1 and predictions.shape[0] == labels.shape[0]:
        # Already one class id per sample; reshaping to (-1, 2) here would pair up
        # unrelated predictions.
        pred_ids = predictions.astype(np.int64)
    else:
        if predictions.ndim == 3:
            predictions = predictions[:, -1, :]
        elif predictions.ndim == 1:
            # Flattened (N * 2) logits.
            predictions = predictions.reshape(-1, 2)
        if predictions.ndim != 2:
            logger.warning(f"Unexpected predictions shape: {predictions.shape}")
            predictions = predictions.reshape(-1, 2)
        pred_ids = np.argmax(predictions, axis=1)

    n_samples = labels.shape[0]
    if pred_ids.shape[0] != n_samples:
        raise ValueError(
            f"compute_metrics: {pred_ids.shape[0]} predictions for {n_samples} labels"
        )
    accuracy = float(np.mean(pred_ids == labels)) if n_samples else 0.0

    metrics = {'accuracy': accuracy}
    f1_per_class = []
    for cls in (0, 1):
        predicted = pred_ids == cls
        actual = labels == cls
        tp = int(np.sum(predicted & actual))
        fp = int(np.sum(predicted & ~actual))
        fn = int(np.sum(~predicted & actual))
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        f1_per_class.append(f1)
        if cls == 1:
            metrics['precision_class1'] = precision
            metrics['recall_class1'] = recall
            metrics['f1_class1'] = f1
    metrics['f1_macro'] = sum(f1_per_class) / len(f1_per_class)
    return metrics
class FocalLoss(nn.Module):
    """Focal Loss for class imbalance: FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t).

    Args:
        alpha: optional per-class weights (tensor or sequence) indexed by target
            class id; ``None`` disables class weighting.
        gamma: focusing parameter; 0 reduces to (alpha-weighted) cross-entropy.
        reduction: ``'mean'`` or ``'sum'``; any other value returns per-sample losses.
    """

    def __init__(self, alpha=None, gamma=3.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Compute the loss; assumes ``inputs`` are (N, C) logits and ``targets`` (N,) class ids."""
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        # p_t (probability of the true class) comes from the *unweighted* CE.
        pt = torch.exp(-ce_loss)
        if self.alpha is not None:
            # Convert a local copy rather than rebinding self.alpha. The old code
            # permanently downcast alpha to fp16 the first time it saw autocast inputs.
            alpha = torch.as_tensor(self.alpha, device=ce_loss.device, dtype=ce_loss.dtype)
            ce_loss = ce_loss * alpha.gather(0, targets.view(-1))
        focal_loss = (1 - pt) ** self.gamma * ce_loss
        if self.reduction == 'mean':
            return focal_loss.mean()
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss
class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d)) V, with optional key mask.

    Works on tensors of any rank ``(..., seq_len, d)``; ``d`` is taken from the last
    dimension of ``query``.

    Args:
        d_model: model width, stored for reference.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, d_model, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        """Return ``(output, attention_weights)``.

        ``mask`` must broadcast to the score shape ``(..., q_len, k_len)``.
        Positions where ``mask == 0`` receive (effectively) zero weight.
        """
        d_k = query.size(-1)
        # Scale the query before the matmul (mathematically equivalent); this keeps
        # the intermediate products smaller, which gives more fp16 headroom.
        scores = torch.matmul(query / math.sqrt(d_k), key.transpose(-2, -1))
        if mask is not None:
            mask_value = torch.finfo(scores.dtype).min
            scores = scores.masked_fill(mask == 0, mask_value)
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        output = torch.matmul(attention_weights, value)
        return output, attention_weights
class RoBERTaWithScaledAttentionAndFocalLoss(nn.Module):
    """BERT-format encoder + extra scaled dot-product attention + linear head, trained with Focal Loss.

    The encoder is loaded with ``BertModel.from_pretrained(model_path)``.
    ``save_pretrained`` records the base model as ``chinese-roberta-wwm-ext``. The
    [CLS] position of the re-attended sequence is classified into ``num_labels``
    classes.

    Args:
        model_path: directory of the pretrained encoder.
        num_labels: number of output classes.
        dropout: dropout for the extra attention and before the classifier.
        focal_alpha: optional per-class alpha tensor for ``FocalLoss``.
        focal_gamma: Focal Loss gamma.
    """

    # Extra forward outputs that Trainer must not accumulate during evaluation.
    _INFERENCE_IGNORE_KEYS = ['hidden_states', 'attention_weights']

    def __init__(self, model_path, num_labels=2, dropout=0.1,
                 focal_alpha=None, focal_gamma=3.0):
        super().__init__()
        self.roberta = BertModel.from_pretrained(model_path)
        self.config = self.roberta.config
        self.config.num_labels = num_labels
        # Trainer.prediction_step excludes these keys (read from model.config) when
        # it gathers outputs. Otherwise (B, L, H) hidden states and (B, L, L)
        # attention maps for the whole eval set would be accumulated just so that
        # compute_metrics could take the logits. forward() itself is unchanged.
        self.config.keys_to_ignore_at_inference = list(self._INFERENCE_IGNORE_KEYS)
        self.scaled_attention = ScaledDotProductAttention(
            d_model=self.config.hidden_size,
            dropout=dropout
        )
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(self.config.hidden_size, num_labels)
        self.focal_loss = FocalLoss(alpha=focal_alpha, gamma=focal_gamma)
        self._init_weights()
        self.focal_alpha = focal_alpha
        self.focal_gamma = focal_gamma

    def _init_weights(self):
        """Initialise the newly added classifier head (normal std=0.02, zero bias)."""
        nn.init.normal_(self.classifier.weight, std=0.02)
        nn.init.zeros_(self.classifier.bias)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, labels=None):
        """Return a dict with ``loss`` (None without labels), ``logits``, ``hidden_states``, ``attention_weights``."""
        roberta_outputs = self.roberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=True
        )
        sequence_output = roberta_outputs.last_hidden_state
        # Padding keys are masked out; unsqueeze(1) broadcasts the mask over queries.
        enhanced_output, attention_weights = self.scaled_attention(
            query=sequence_output,
            key=sequence_output,
            value=sequence_output,
            mask=attention_mask.unsqueeze(1) if attention_mask is not None else None
        )
        cls_output = enhanced_output[:, 0, :]
        cls_output = self.dropout(cls_output)
        logits = self.classifier(cls_output)
        loss = None
        if labels is not None:
            loss = self.focal_loss(logits, labels)
        return {
            'loss': loss,
            'logits': logits,
            'hidden_states': enhanced_output,
            'attention_weights': attention_weights
        }

    def save_pretrained(self, save_directory):
        """Save the state dict as ``pytorch_model.bin`` plus a descriptive ``config.json``."""
        os.makedirs(save_directory, exist_ok=True)
        model_to_save = self.module if hasattr(self, 'module') else self
        torch.save(model_to_save.state_dict(), os.path.join(save_directory, 'pytorch_model.bin'))
        config_dict = {
            'model_type': 'RoBERTaWithScaledAttentionAndFocalLoss',
            'base_model': 'chinese-roberta-wwm-ext',
            'num_labels': self.config.num_labels,
            'hidden_size': self.config.hidden_size,
            'focal_alpha': self.focal_alpha.tolist() if self.focal_alpha is not None else None,
            'focal_gamma': self.focal_gamma,
            'has_scaled_attention': True,
            'has_focal_loss': True,
            'optimization_level': 'high_priority_v100_with_validation_resumed'
        }
        with open(os.path.join(save_directory, 'config.json'), 'w', encoding='utf-8') as f:
            json.dump(config_dict, f, ensure_ascii=False, indent=2)
class WeightedTrainer(Trainer):
    """``Trainer`` whose training DataLoader can use an externally supplied sampler.

    When ``weighted_sampler`` is given (for example a ``WeightedRandomSampler``), it
    replaces the sampler Trainer would build; otherwise
    ``Trainer._get_train_sampler()`` is used.

    NOTE(review): ``weighted_sampler`` is the first positional parameter, so pass
    all Trainer arguments by keyword.
    """

    def __init__(self, weighted_sampler=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.weighted_sampler = weighted_sampler

    def get_train_dataloader(self):
        """Build the training DataLoader from the current training arguments."""
        dataset = self.train_dataset
        if dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")
        sampler = (self.weighted_sampler
                   if self.weighted_sampler is not None
                   else self._get_train_sampler())
        loader_options = dict(
            batch_size=self.args.train_batch_size,
            sampler=sampler,
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
        )
        return DataLoader(dataset, **loader_options)
def _dump_json(path, payload):
    """Write ``payload`` to ``path`` as indented UTF-8 JSON, keeping non-ASCII text readable."""
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(payload, f, ensure_ascii=False, indent=2)


def resume_training_from_checkpoint(train_data, val_data,
                                    checkpoint_path="/root/autodl-tmp/chinese-roberta-wwm-ext/chinese-roberta-wwm-ext/ouput_result/checkpoint-20240",
                                    model_path="/root/autodl-tmp/chinese-roberta-wwm-ext/chinese-roberta-wwm-ext/model",
                                    output_dir="/root/autodl-tmp/chinese-roberta-wwm-ext/chinese-roberta-wwm-ext/model_train",
                                    checkpoint_dir="/root/autodl-tmp/chinese-roberta-wwm-ext/chinese-roberta-wwm-ext/ouput_result"):
    """Resume training the Focal-Loss sentence-pair model from a Trainer checkpoint.

    Builds the model, datasets, weighted sampler, callbacks and ``WeightedTrainer``,
    then calls ``trainer.train(resume_from_checkpoint=checkpoint_path)``. Afterwards:
    - the final model and tokenizer are saved to ``output_dir``;
    - loss curves, loss/confusion-matrix histories and a training-info JSON are
      written to ``checkpoint_dir`` (a copy of the info goes to ``output_dir``).

    Args:
        train_data: training records (dicts with ``sentence1``, ``sentence2``, ``label``).
        val_data: validation records with the same schema.
        checkpoint_path: Trainer checkpoint directory to resume from; must exist.
        model_path: directory of the base pretrained encoder and tokenizer.
        output_dir: destination of the final model, tokenizer and summary JSON.
        checkpoint_dir: Trainer ``output_dir``; new checkpoints, plots and histories go here.

    Returns:
        tuple ``(trainer, model, tokenizer, loss_tracker,
        val_confusion_matrix_callback, train_confusion_matrix_callback)``.

    Raises:
        RuntimeError: no GPU is available, or training fails (CUDA OOM is logged first).
        FileNotFoundError: ``checkpoint_path`` does not exist.
    """
    # check_gpu_availability raises when there is no GPU, so its flag is always True.
    _, gpu_memory = check_gpu_availability()
    device = torch.device('cuda')
    logger.info(f"🚀 使用GPU设备: {device}")

    # Fail fast if the checkpoint to resume from is missing.
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"❌ Checkpoint路径不存在: {checkpoint_path}")
    logger.info(f"📂 从checkpoint恢复训练: {checkpoint_path}")
    checkpoint_name = os.path.basename(os.path.normpath(checkpoint_path))

    # Label distribution -> Focal Loss alpha/gamma.
    label_distribution, alpha_weights, recommended_gamma = analyze_data_distribution(train_data)
    alpha_tensor = torch.tensor(alpha_weights, dtype=torch.float).to(device)

    logger.info(f"📥 加载Chinese-RoBERTa-WWM-Ext模型: {model_path}")
    tokenizer = BertTokenizer.from_pretrained(model_path)
    model = RoBERTaWithScaledAttentionAndFocalLoss(
        model_path=model_path,
        num_labels=2,
        dropout=0.1,
        focal_alpha=alpha_tensor,
        focal_gamma=recommended_gamma
    )
    model = model.to(device)
    logger.info(f"✅ 模型已加载到GPU: {next(model.parameters()).device}")
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info("📊 模型参数统计:")
    logger.info(f" 🔹 总参数量: {total_params:,}")
    logger.info(f" 🔹 可训练参数: {trainable_params:,}")

    logger.info("📋 准备训练数据集和验证数据集...")
    train_dataset = SentencePairDataset(train_data, tokenizer, max_length=512, is_validation=False)
    val_dataset = SentencePairDataset(val_data, tokenizer, max_length=512, is_validation=True)
    weighted_sampler = train_dataset.get_weighted_sampler()
    logger.info(f" 🔹 训练集大小: {len(train_dataset)}")
    logger.info(f" 🔹 验证集大小: {len(val_dataset)}")
    logger.info(f" 🔹 类别权重: {train_dataset.class_weights}")

    # Hyper-parameters kept identical to the run that produced the checkpoint.
    batch_size = 16
    gradient_accumulation = 2
    max_grad_norm = 1.0
    fp16 = True
    dataloader_num_workers = 4
    effective_batch_size = batch_size * gradient_accumulation
    initial_learning_rate = 2e-5
    warmup_ratio = 0.15

    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)
    training_args = TrainingArguments(
        output_dir=checkpoint_dir,
        num_train_epochs=100,  # continue training up to epoch 100
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        gradient_accumulation_steps=gradient_accumulation,
        eval_strategy="epoch",
        eval_steps=1,  # ignored under the "epoch" strategy; kept to match the checkpointed args
        save_strategy="epoch",
        save_steps=20,  # ignored under the "epoch" strategy; kept to match the checkpointed args
        logging_strategy="steps",
        logging_steps=50,
        warmup_ratio=warmup_ratio,
        weight_decay=0.01,
        learning_rate=initial_learning_rate,
        load_best_model_at_end=False,
        remove_unused_columns=False,
        dataloader_pin_memory=True,
        fp16=fp16,
        dataloader_num_workers=dataloader_num_workers,
        group_by_length=True,
        report_to=[],
        adam_epsilon=1e-8,
        max_grad_norm=max_grad_norm,
        save_total_limit=5,
        skip_memory_metrics=True,
        disable_tqdm=False,
        lr_scheduler_type="cosine",
        warmup_steps=0,  # 0 means warmup_ratio is used
        metric_for_best_model="eval_accuracy",
        greater_is_better=True,
        resume_from_checkpoint=checkpoint_path,
    )
    logger.info(f"🔄 从checkpoint恢复训练参数:")
    logger.info(f" 🔹 Checkpoint路径: {checkpoint_path}")
    logger.info(f" 🔹 目标训练轮数: {training_args.num_train_epochs}")
    logger.info(f" 🔹 批次大小: {batch_size}")
    logger.info(f" 🔹 梯度累积: {gradient_accumulation}")
    logger.info(f" 🔹 有效批次大小: {effective_batch_size}")
    logger.info(f" 🔹 学习率: {training_args.learning_rate}")
    logger.info(f" 🔹 预热比例: {warmup_ratio}")
    logger.info(f" 🔹 序列长度: 512")
    logger.info(f" 🔹 混合精度: {fp16}")
    logger.info(f" 🔹 验证策略: 每个epoch评估")

    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
    loss_tracker = LossTracker()
    val_confusion_matrix_callback = ValidationConfusionMatrixCallback(
        eval_dataset=val_dataset,
        tokenizer=tokenizer,
        output_dir=checkpoint_dir,
        epochs_interval=10
    )
    train_confusion_matrix_callback = TrainingConfusionMatrixCallback(
        train_dataset=train_dataset,
        tokenizer=tokenizer,
        output_dir=checkpoint_dir,
        epochs_interval=20
    )
    trainer = WeightedTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        callbacks=[loss_tracker, val_confusion_matrix_callback, train_confusion_matrix_callback],
        weighted_sampler=weighted_sampler
    )

    # Report the values actually in use; they were previously hard-coded
    # ("[0.1, 0.9]", "3.0-3.5") regardless of what analyze_data_distribution returned.
    logger.info("🔄 开始从checkpoint恢复训练...")
    logger.info("🎯 恢复训练配置:")
    logger.info(f" ✅ Focal Loss Gamma: {recommended_gamma}")
    logger.info(f" ✅ Alpha权重: [{alpha_weights[0]:.3f}, {alpha_weights[1]:.3f}]")
    logger.info(f" ✅ 学习率: {initial_learning_rate}")
    logger.info(f" ✅ 预热比例: {warmup_ratio:.0%}")
    logger.info(" ✅ WeightedRandomSampler")
    logger.info(" ✅ 余弦退火学习率调度")
    logger.info(" ✅ 验证集: 每个epoch评估")
    logger.info(f" ✅ 验证集混淆矩阵: 每{val_confusion_matrix_callback.epochs_interval}个epoch生成 (标记为resumed)")
    logger.info(f" ✅ 训练集混淆矩阵: 每{train_confusion_matrix_callback.epochs_interval}个epoch生成 (标记为resumed)")

    start_time = datetime.now()
    try:
        trainer.train(resume_from_checkpoint=checkpoint_path)
    except RuntimeError as e:
        # CUDA OOM surfaces as a RuntimeError; add a hint, then always re-raise.
        if "out of memory" in str(e).lower():
            logger.error("❌ GPU内存不足!")
            logger.error("💡 建议减小批次大小")
        raise
    end_time = datetime.now()
    training_duration = end_time - start_time
    logger.info(f"🎉 从checkpoint恢复训练完成! 耗时: {training_duration}")

    logger.info("📈 生成训练可视化图表...")
    plot_training_curves(loss_tracker, checkpoint_dir)

    logger.info(f"💾 保存最终模型到: {output_dir}")
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)

    # Histories (tagged "resumed") go to the checkpoint directory.
    loss_history = {
        'train_losses': loss_tracker.train_losses,
        'train_steps': loss_tracker.train_steps,
        'eval_losses': loss_tracker.eval_losses,
        'eval_steps': loss_tracker.eval_steps,
        'resumed_from_checkpoint': checkpoint_path,
        'resume_time': start_time.isoformat()
    }
    _dump_json(os.path.join(checkpoint_dir, 'loss_history_resumed.json'), loss_history)
    val_cm_history = {epoch: cm.tolist() for epoch, cm in val_confusion_matrix_callback.confusion_matrices.items()}
    _dump_json(os.path.join(checkpoint_dir, 'validation_confusion_matrix_history_resumed.json'), val_cm_history)
    train_cm_history = {epoch: cm.tolist() for epoch, cm in train_confusion_matrix_callback.confusion_matrices.items()}
    _dump_json(os.path.join(checkpoint_dir, 'training_confusion_matrix_history_resumed.json'), train_cm_history)

    training_info = {
        "model_name": model_path,
        "model_type": "Chinese-RoBERTa-WWM-Ext with Optimized Focal Loss and Weighted Sampling",
        "optimization_level": "high_priority_v100_48gb_with_validation_resumed",
        "resumed_from_checkpoint": checkpoint_path,
        "resume_training_duration": str(training_duration),
        "resume_start_time": start_time.isoformat(),
        "resume_end_time": end_time.isoformat(),
        "num_train_samples": len(train_dataset),
        "num_val_samples": len(val_dataset),
        "validation_split": len(val_dataset) / (len(train_dataset) + len(val_dataset)),
        "label_distribution": label_distribution,
        "data_imbalance": {
            "class_0_count": label_distribution.get(0, 0),
            "class_1_count": label_distribution.get(1, 0),
            "class_0_ratio": label_distribution.get(0, 0) / len(train_dataset),
            "class_1_ratio": label_distribution.get(1, 0) / len(train_dataset),
            "imbalance_ratio": label_distribution.get(0, 1) / label_distribution.get(1, 1)
        },
        "optimized_focal_loss_params": {
            "alpha_weights": alpha_weights,
            "gamma": recommended_gamma,
            "formula": "FL(p_t) = -α_t * (1-p_t)^γ * log(p_t)",
            "optimization": "aggressive_minority_class_focus"
        },
        "weighted_sampling": {
            "enabled": True,
            "class_weights": train_dataset.class_weights,
            "sampler_type": "WeightedRandomSampler",
            "applies_to": "training_set_only"
        },
        "validation_setup": {
            "enabled": True,
            "validation_split": "20%",
            "stratified_split": True,
            "eval_strategy": "every_epoch",
            "confusion_matrix_frequency": "every_10_epochs"
        },
        "optimized_learning_strategy": {
            "initial_learning_rate": initial_learning_rate,
            "warmup_ratio": warmup_ratio,
            "lr_scheduler": "cosine",
            "improvement": "optimized_for_v100"
        },
        "gpu_optimization": {
            "gpu_name": torch.cuda.get_device_name(0),
            "gpu_memory_gb": gpu_memory,
            "optimization_target": "V100_48GB",
            "effective_batch_size": effective_batch_size,
            "sequence_length": 512,
            "batch_size_optimization": "v100_optimized"
        },
        "training_args": {
            "num_train_epochs": training_args.num_train_epochs,
            "per_device_train_batch_size": training_args.per_device_train_batch_size,
            "per_device_eval_batch_size": training_args.per_device_eval_batch_size,
            "gradient_accumulation_steps": training_args.gradient_accumulation_steps,
            "learning_rate": training_args.learning_rate,
            "warmup_ratio": training_args.warmup_ratio,
            "weight_decay": training_args.weight_decay,
            "fp16": training_args.fp16,
            "lr_scheduler_type": training_args.lr_scheduler_type,
            "eval_strategy": training_args.eval_strategy,
            "resume_from_checkpoint": training_args.resume_from_checkpoint
        },
        "model_parameters": {
            "total_params": total_params,
            "trainable_params": trainable_params,
        },
        "paths": {
            "model_input_path": model_path,
            "model_output_path": output_dir,
            "checkpoint_output_path": checkpoint_dir,
            "resume_checkpoint_path": checkpoint_path,
            "data_path": "/root/autodl-tmp/chinese-roberta-wwm-ext/chinese-roberta-wwm-ext/Data"
        },
        # Entries derived from parameters reflect the values actually used in this run.
        "high_priority_optimizations": [
            f"Focal Loss Gamma set to {recommended_gamma}",
            f"Alpha weights set to [{alpha_weights[0]:.3f}, {alpha_weights[1]:.3f}]",
            f"Learning rate optimized for V100: {initial_learning_rate}",
            f"Warmup ratio increased to {warmup_ratio:.0%}",
            "WeightedRandomSampler for balanced class sampling",
            "Cosine annealing learning rate scheduler",
            f"V100 48GB optimized batch size: {batch_size}",
            "Full sequence length: 512 tokens",
            "Validation set with stratified split",
            f"Validation confusion matrix every {val_confusion_matrix_callback.epochs_interval} epochs",
            f"Training resumed from {checkpoint_name}"
        ],
        "visualization_files": {
            "training_curves": "training_curves_resumed.png",
            "loss_comparison": "loss_comparison_curves_resumed.png",
            "validation_confusion_matrices": [f"validation_confusion_matrix_epoch_{i}_resumed.png" for i in range(10, 101, 10)],
            "training_confusion_matrices": [f"training_confusion_matrix_epoch_{i}_resumed.png" for i in range(20, 101, 20)],
            "loss_history": "loss_history_resumed.json",
            "validation_confusion_matrix_history": "validation_confusion_matrix_history_resumed.json",
            "training_confusion_matrix_history": "training_confusion_matrix_history_resumed.json"
        },
        "training_completed": end_time.isoformat()
    }
    _dump_json(os.path.join(checkpoint_dir, 'training_info_resumed.json'), training_info)
    # Keep a copy of the summary next to the saved model as well.
    _dump_json(os.path.join(output_dir, 'training_summary_resumed.json'), training_info)
    logger.info("📋 恢复训练信息已保存")
    return trainer, model, tokenizer, loss_tracker, val_confusion_matrix_callback, train_confusion_matrix_callback
def _log_completion_summary(output_dir: str, checkpoint_dir: str) -> None:
    """Log the post-training banner: output locations, expected artifacts and feature list.

    Purely informational. The artifact and feature lines are static text; they are
    not checked against what was actually written to disk.

    Args:
        output_dir: Directory the trained model is written to.
        checkpoint_dir: Directory holding training records, plots and checkpoints.
    """
    logger.info("=" * 120)
    logger.info("🎉 V100 48GB从Checkpoint恢复训练完成!")
    logger.info("=" * 120)
    logger.info("📁 文件输出位置:")
    logger.info(f" 🔹 训练好的模型: {output_dir}")
    logger.info(f" 🔹 训练记录和图表: {checkpoint_dir}")
    # Static description of the generated artifacts and the run's features.
    static_lines = (
        "📄 生成的文件(恢复训练):",
        " 模型文件 (model_train目录):",
        " • pytorch_model.bin - 恢复训练的模型权重",
        " • config.json - 优化模型配置",
        " • tokenizer配置文件",
        " • training_summary_resumed.json - 恢复训练摘要",
        " 训练记录 (ouput_result目录):",
        " • training_info_resumed.json - 详细恢复训练信息",
        " • loss_history_resumed.json - 恢复训练的损失历史",
        " • validation_confusion_matrix_history_resumed.json - 验证集混淆矩阵历史",
        " • training_confusion_matrix_history_resumed.json - 训练集混淆矩阵历史",
        " • training_curves_resumed.png - 恢复训练损失曲线",
        " • loss_comparison_curves_resumed.png - 训练vs验证损失对比",
        " • validation_confusion_matrix_epoch_X_resumed.png - 验证集混淆矩阵",
        " • training_confusion_matrix_epoch_X_resumed.png - 训练集混淆矩阵",
        " • checkpoint-* - 新的训练检查点",
        "🔥 V100 48GB恢复训练特性:",
        " ✅ 从checkpoint-20240成功恢复",
        " ✅ 保持所有原有优化参数",
        " ✅ Chinese-RoBERTa-WWM-Ext 基础模型",
        " ✅ 激进的Focal Loss参数 (Gamma=3.0+, Alpha=[0.1,0.9])",
        " ✅ V100优化学习率: 2e-5",
        " ✅ 大批次训练: 16 (有效批次: 32)",
        " ✅ 完整序列长度: 512 tokens",
        " ✅ WeightedRandomSampler 平衡采样",
        " ✅ 余弦退火学习率调度",
        " ✅ 缩放点积注意力机制",
        " ✅ 验证集分层划分 (20%)",
        " ✅ 每个epoch验证评估",
        " ✅ 验证集混淆矩阵每10个epoch (标记resumed)",
        " ✅ 训练集混淆矩阵每20个epoch (标记resumed)",
        " ✅ 继续训练到100 epochs",
        " ✅ 完整可视化监控",
        "🎯 从checkpoint-20240恢复的优势:",
        " ⚡ 保留之前的训练进展",
        " ⚡ 继承已学习的模型权重",
        " ⚡ 维持原有的优化器状态",
        " ⚡ 保持学习率调度进度",
        " ⚡ 节省重新训练时间",
        " ⚡ 无缝继续训练流程",
    )
    for line in static_lines:
        logger.info(line)


def _log_model_files(output_dir: str) -> None:
    """Log every regular file directly inside ``output_dir`` with its size in MB.

    Best effort: filesystem errors are logged as a warning and never raised.

    Args:
        output_dir: Directory the trained model was written to.
    """
    logger.info(f"📋 模型文件 ({output_dir}):")
    try:
        # Sorted so the listing is stable between runs (os.listdir order is arbitrary).
        for name in sorted(os.listdir(output_dir)):
            file_path = os.path.join(output_dir, name)
            if os.path.isfile(file_path):
                size_mb = os.path.getsize(file_path) / (1024 * 1024)
                logger.info(f" 📄 {name} ({size_mb:.2f} MB)")
    except OSError as e:
        logger.warning(f" 无法列出模型文件: {str(e)}")


def _log_training_records(checkpoint_dir: str) -> None:
    """Log the ``*resumed*`` artifacts and ``checkpoint-*`` directories in ``checkpoint_dir``.

    Entries matching neither pattern are not listed. Best effort: filesystem errors
    are logged as a warning and never raised.

    Args:
        checkpoint_dir: Directory holding training records, plots and checkpoints.
    """
    logger.info(f"📋 训练记录 ({checkpoint_dir}):")
    try:
        entries = os.listdir(checkpoint_dir)
        # Artifacts produced by this resumed run carry 'resumed' in their name.
        resumed_files = sorted(name for name in entries if 'resumed' in name)
        checkpoint_dirs = sorted(
            name for name in entries
            if 'resumed' not in name and name.startswith('checkpoint-')
        )
        if resumed_files:
            logger.info(" 恢复训练文件:")
            for name in resumed_files:
                file_path = os.path.join(checkpoint_dir, name)
                if not os.path.isfile(file_path):
                    continue
                size_kb = os.path.getsize(file_path) / 1024
                if name.endswith('.json'):
                    logger.info(f" 📄 {name} ({size_kb:.1f} KB)")
                elif name.endswith('.png'):
                    logger.info(f" 📊 {name} ({size_kb:.1f} KB)")
                else:
                    logger.info(f" 📄 {name} ({size_kb / 1024:.2f} MB)")
        if checkpoint_dirs:
            logger.info(" 训练检查点:")
            for name in checkpoint_dirs:
                logger.info(f" 📁 {name}/")
    except OSError as e:
        logger.warning(f" 无法列出训练记录: {str(e)}")


def main() -> None:
    """Resume Chinese-RoBERTa-WWM-Ext training from a fixed checkpoint.

    Steps:
      1. Ensure the model output directory and the training-record directory exist.
      2. Log an error and return if the resume checkpoint directory is missing.
      3. Log the run configuration.
      4. Load the training data and split it with ``validation_split=0.2`` and
         ``random_state=42``. This is the same seed as the earlier run, so the
         split is consistent.
      5. Call ``resume_training_from_checkpoint``. On failure, log the exception
         with its traceback and re-raise it.
      6. Log a completion summary and list the files that were written.

    Raises:
        Exception: Whatever ``resume_training_from_checkpoint`` raises, re-raised
            unchanged after logging.
    """
    logger.info("=" * 120)
    logger.info("🔄 Chinese-RoBERTa-WWM-Ext V100 48GB从Checkpoint恢复训练")
    logger.info("=" * 120)

    # Path configuration; every path lives under one project root.
    base_dir = "/root/autodl-tmp/chinese-roberta-wwm-ext/chinese-roberta-wwm-ext"
    train_file = os.path.join(base_dir, "Data", "train_dataset.json")
    model_path = os.path.join(base_dir, "model")
    output_dir = os.path.join(base_dir, "model_train")
    # 'ouput_result' (sic) must match the existing directory that holds the
    # checkpoints, so the spelling is kept as-is.
    checkpoint_dir = os.path.join(base_dir, "ouput_result")
    resume_checkpoint = os.path.join(checkpoint_dir, "checkpoint-20240")

    # Make sure all output directories exist.
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)
    logger.info("📁 确保输出目录存在:")
    logger.info(f" 🔹 模型输出: {output_dir}")
    logger.info(f" 🔹 训练记录: {checkpoint_dir}")

    # Bail out early if there is nothing to resume from.
    # NOTE(review): this returns normally (exit status 0) instead of raising, so a
    # job scheduler will not see the failure.
    if not os.path.exists(resume_checkpoint):
        logger.error(f"❌ Checkpoint不存在: {resume_checkpoint}")
        logger.info("💡 请检查checkpoint路径是否正确")
        return

    # Report the third-party reporting switches (set via os.environ elsewhere in the file).
    logger.info("🚫 确认第三方报告工具状态:")
    logger.info(f" 🔹 WANDB_DISABLED: {os.environ.get('WANDB_DISABLED', 'not set')}")
    logger.info(f" 🔹 COMET_MODE: {os.environ.get('COMET_MODE', 'not set')}")
    logger.info(f" 🔹 NEPTUNE_MODE: {os.environ.get('NEPTUNE_MODE', 'not set')}")
    logger.info(" ✅ 所有第三方报告工具已禁用")

    logger.info("\n📋 V100 48GB从Checkpoint恢复配置:")
    logger.info(f" 🔹 训练数据: {train_file}")
    logger.info(f" 🔹 基础模型: {model_path}")
    logger.info(f" 🔹 恢复checkpoint: {resume_checkpoint}")
    logger.info(" 🔹 模型类型: Chinese-RoBERTa-WWM-Ext")
    logger.info(" 🔹 优化等级: V100 48GB高性能优化")
    logger.info(" 🔹 验证集: 20%分层划分")
    logger.info(" 🔹 目标: 处理严重数据不平衡问题")
    logger.info(" 🔹 核心优化:")
    logger.info(" • Focal Loss Gamma: 3.0+ (增强难样本聚焦)")
    logger.info(" • Alpha权重: [0.1, 0.9] (激进的少数类关注)")
    logger.info(" • 学习率: 2e-5 (V100优化)")
    logger.info(" • 批次大小: 16 (V100大显存优化)")
    logger.info(" • 序列长度: 512 (完整长度)")
    logger.info(" • WeightedRandomSampler (平衡采样)")
    logger.info(" • 验证集每个epoch评估")
    logger.info(" • 验证集混淆矩阵每10个epoch生成")
    logger.info(" 🔹 目标训练轮数: 100 epochs")
    logger.info(f" 🔹 模型输出: {output_dir}")
    logger.info(f" 🔹 训练记录: {checkpoint_dir}")

    # Load and split the data; the fixed seed keeps the split consistent with the earlier run.
    train_data, val_data = load_and_split_data(train_file, validation_split=0.2, random_state=42)
    if train_data is None or val_data is None:
        logger.error("❌ 无法加载和划分数据,程序退出")
        return

    # Only the training call is guarded, so the error message below is never
    # attributed to a failure in the reporting code that follows.
    try:
        resume_training_from_checkpoint(
            train_data, val_data,
            checkpoint_path=resume_checkpoint,
            model_path=model_path,
            output_dir=output_dir,
            checkpoint_dir=checkpoint_dir,
        )
    except Exception as e:
        # logger.exception records the traceback through the configured logging
        # handlers (instead of printing it to stderr separately), then we propagate.
        logger.exception(f"❌ 从Checkpoint恢复训练过程中出现错误: {str(e)}")
        raise

    _log_completion_summary(output_dir, checkpoint_dir)

    # Show where everything was saved.
    logger.info("\n📂 文件保存详情:")
    _log_model_files(output_dir)
    _log_training_records(checkpoint_dir)

    logger.info("\n🎯 恢复训练完成,可以开始评估模型性能!")
    logger.info("📊 建议关注:")
    logger.info(" • 恢复训练后的损失曲线变化")
    logger.info(" • 验证集混淆矩阵的改进趋势")
    logger.info(" • 训练损失vs验证损失的收敛情况")
    logger.info(" • 少数类召回率和精确率的最终表现")


if __name__ == "__main__":
    main()