import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, f1_score, precision_score, \
    recall_score
from transformers import (
    BertTokenizer,
    BertForSequenceClassification,
    BertModel,
    BertConfig,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding,
    TrainerCallback,
    EarlyStoppingCallback
)
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler, random_split
import logging
import os
from datetime import datetime
import math
from collections import defaultdict, Counter

# Disable wandb and other third-party reporting tools
os.environ["WANDB_DISABLED"] = "true"
os.environ["COMET_MODE"] = "disabled"
os.environ["NEPTUNE_MODE"] = "disabled"

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Configure matplotlib fonts
plt.rcParams['font.sans-serif'] = ['DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False


def check_gpu_availability():
    """Check GPU availability."""
    if not torch.cuda.is_available():
        raise RuntimeError("❌ GPU unavailable! This script requires GPU support.")
    gpu_count = torch.cuda.device_count()
    gpu_name = torch.cuda.get_device_name(0)
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024 ** 3
    logger.info("✅ GPU check passed!")
    logger.info(f" 🔹 Available GPUs: {gpu_count}")
    logger.info(f" 🔹 GPU model: {gpu_name}")
    logger.info(f" 🔹 GPU memory: {gpu_memory:.1f} GB")
    torch.cuda.empty_cache()
    torch.backends.cudnn.benchmark = True
    return True, gpu_memory


class LossTracker(TrainerCallback):
    """Callback that tracks training and evaluation losses."""

    def __init__(self):
        self.train_losses = []
        self.eval_losses = []
        self.train_steps = []
        self.eval_steps = []
        self.eval_f1_scores = []
        self.eval_epochs = []
        self.current_epoch = 0

    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs:
            if 'loss' in logs:
                self.train_losses.append(logs['loss'])
                self.train_steps.append(state.global_step)
            if 'eval_loss' in logs:
                self.eval_losses.append(logs['eval_loss'])
                self.eval_steps.append(state.global_step)
            if 'eval_f1_macro' in logs:
                self.eval_f1_scores.append(logs['eval_f1_macro'])
                self.eval_epochs.append(state.epoch)

    def on_epoch_end(self, args, state, control, **kwargs):
        self.current_epoch = state.epoch


class ValidationMetricsCallback(TrainerCallback):
    """Callback that records validation metrics after each evaluation."""

    def __init__(self):
        self.validation_history = []

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        # Note: the Trainer passes evaluation results via the `metrics` keyword,
        # not `logs`, so the hook must accept `metrics` to ever see the values.
        if metrics:
            epoch = int(state.epoch)
            entry = {
                'epoch': epoch,
                'eval_loss': metrics.get('eval_loss', 0),
                'eval_accuracy': metrics.get('eval_accuracy', 0),
                'eval_f1_minority': metrics.get('eval_f1_minority', 0),
                'eval_f1_macro': metrics.get('eval_f1_macro', 0),
                'eval_precision_minority': metrics.get('eval_precision_minority', 0),
                'eval_recall_minority': metrics.get('eval_recall_minority', 0)
            }
            self.validation_history.append(entry)
            logger.info(f"📊 Epoch {epoch} validation metrics:")
            logger.info(f" 🔹 Validation loss: {entry['eval_loss']:.4f}")
            logger.info(f" 🔹 Validation accuracy: {entry['eval_accuracy']:.4f}")
            logger.info(f" 🔹 Macro F1: {entry['eval_f1_macro']:.4f}")
            logger.info(f" 🔹 Minority-class F1: {entry['eval_f1_minority']:.4f}")
            logger.info(f" 🔹 Minority-class precision: {entry['eval_precision_minority']:.4f}")
            logger.info(f" 🔹 Minority-class recall: {entry['eval_recall_minority']:.4f}")

class ConfusionMatrixCallback(TrainerCallback):
    """Callback that periodically generates confusion matrices on the validation set."""

    def __init__(self, eval_dataset, tokenizer, output_dir, epochs_interval=10):
        self.eval_dataset = eval_dataset
        self.tokenizer = tokenizer
        self.output_dir = output_dir
        self.epochs_interval = epochs_interval
        self.confusion_matrices = {}

    def on_evaluate(self, args, state, control, model=None, **kwargs):
        current_epoch = int(state.epoch)
        if current_epoch % self.epochs_interval == 0:
            logger.info(f"📊 Generating confusion matrix for epoch {current_epoch}...")
            model.eval()
            predictions = []
            true_labels = []
            device = next(model.parameters()).device
            with torch.no_grad():
                # Cap the number of samples to speed up evaluation
                eval_size = min(1000, len(self.eval_dataset))
                for i in range(eval_size):
                    item = self.eval_dataset[i]
                    input_ids = item['input_ids'].unsqueeze(0).to(device)
                    attention_mask = item['attention_mask'].unsqueeze(0).to(device)
                    outputs = model(input_ids=input_ids, attention_mask=attention_mask)
                    pred = torch.argmax(outputs['logits'], dim=-1).cpu().item()
                    predictions.append(pred)
                    true_labels.append(item['labels'].item())
            cm = confusion_matrix(true_labels, predictions)
            self.confusion_matrices[current_epoch] = cm
            self.save_confusion_matrix(cm, current_epoch)
            model.train()

    def save_confusion_matrix(self, cm, epoch):
        """Save a confusion matrix plot."""
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=['Same Paragraph (0)', 'Different Paragraph (1)'],
                    yticklabels=['Same Paragraph (0)', 'Different Paragraph (1)'])
        plt.title(f'Validation Confusion Matrix - Epoch {epoch}')
        plt.xlabel('Predicted Label')
        plt.ylabel('True Label')
        accuracy = np.trace(cm) / np.sum(cm)
        plt.text(0.5, -0.15, f'Validation Accuracy: {accuracy:.4f}',
                 ha='center', transform=plt.gca().transAxes)
        plt.tight_layout()
        save_path = os.path.join(self.output_dir, f'val_confusion_matrix_epoch_{epoch}.png')
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close()
        logger.info(f" 💾 Validation confusion matrix saved: {save_path}")
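
# The per-example loop in ConfusionMatrixCallback.on_evaluate is easy to read
# but slow for large validation sets. Below is a minimal batched sketch
# (`_batched_predictions` is a hypothetical helper, not used by the training
# pipeline); it assumes dataset items are the fixed-length tensor dicts that
# SentencePairDataset.__getitem__ returns, so the default collate can stack them.
def _batched_predictions(model, dataset, batch_size=32, max_samples=1000):
    """Collect (predictions, labels) over up to max_samples items in batches."""
    device = next(model.parameters()).device
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    predictions, true_labels = [], []
    model.eval()
    with torch.no_grad():
        for batch in loader:
            outputs = model(input_ids=batch['input_ids'].to(device),
                            attention_mask=batch['attention_mask'].to(device))
            predictions.extend(torch.argmax(outputs['logits'], dim=-1).cpu().tolist())
            true_labels.extend(batch['labels'].tolist())
            if len(predictions) >= max_samples:
                break
    return predictions[:max_samples], true_labels[:max_samples]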

def plot_training_curves(loss_tracker, validation_metrics, output_dir):
    """Plot training curves and validation metrics."""
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))

    # 1. Training loss curve
    if loss_tracker.train_losses:
        axes[0, 0].plot(loss_tracker.train_steps, loss_tracker.train_losses,
                        'b-', label='Training Loss', linewidth=2, alpha=0.8)
        axes[0, 0].set_title('Training Loss Curve', fontsize=14, fontweight='bold')
        axes[0, 0].set_xlabel('Training Steps')
        axes[0, 0].set_ylabel('Loss Value')
        axes[0, 0].legend()
        axes[0, 0].grid(True, alpha=0.3)

    # 2. Validation loss curve
    if loss_tracker.eval_losses:
        axes[0, 1].plot(loss_tracker.eval_steps, loss_tracker.eval_losses,
                        'r-', label='Validation Loss', linewidth=2, alpha=0.8)
        axes[0, 1].set_title('Validation Loss Curve', fontsize=14, fontweight='bold')
        axes[0, 1].set_xlabel('Training Steps')
        axes[0, 1].set_ylabel('Loss Value')
        axes[0, 1].legend()
        axes[0, 1].grid(True, alpha=0.3)

    # 3. Macro F1 curve
    if validation_metrics:
        epochs = [m['epoch'] for m in validation_metrics]
        f1_scores = [m['eval_f1_macro'] for m in validation_metrics]
        axes[1, 0].plot(epochs, f1_scores, 'g-', marker='o', label='Macro F1',
                        linewidth=2, alpha=0.8)
        axes[1, 0].set_title('Macro F1 Score', fontsize=14, fontweight='bold')
        axes[1, 0].set_xlabel('Epoch')
        axes[1, 0].set_ylabel('F1 Score')
        axes[1, 0].legend()
        axes[1, 0].grid(True, alpha=0.3)

        # Annotate the best F1 score
        best_f1_idx = np.argmax(f1_scores)
        best_epoch = epochs[best_f1_idx]
        best_f1 = f1_scores[best_f1_idx]
        axes[1, 0].annotate(f'Best F1: {best_f1:.4f}\nEpoch: {best_epoch}',
                            xy=(best_epoch, best_f1), xytext=(10, 10),
                            textcoords='offset points', ha='left',
                            bbox=dict(boxstyle='round,pad=0.3', facecolor='yellow', alpha=0.7),
                            arrowprops=dict(arrowstyle='->', connectionstyle='arc3,rad=0'))

    # 4. Combined metric comparison
    if validation_metrics:
        epochs = [m['epoch'] for m in validation_metrics]
        accuracy = [m['eval_accuracy'] for m in validation_metrics]
        f1_minority = [m['eval_f1_minority'] for m in validation_metrics]
        f1_macro = [m['eval_f1_macro'] for m in validation_metrics]
        axes[1, 1].plot(epochs, accuracy, 'b-', label='Accuracy', linewidth=2, alpha=0.8)
        axes[1, 1].plot(epochs, f1_minority, 'r-', label='Minority F1', linewidth=2, alpha=0.8)
        axes[1, 1].plot(epochs, f1_macro, 'g-', label='Macro F1', linewidth=2, alpha=0.8)
        axes[1, 1].set_title('Validation Metrics Comparison', fontsize=14, fontweight='bold')
        axes[1, 1].set_xlabel('Epoch')
        axes[1, 1].set_ylabel('Score')
        axes[1, 1].legend()
        axes[1, 1].grid(True, alpha=0.3)

    plt.tight_layout()

    # Save the training curves
    curves_path = os.path.join(output_dir, 'training_validation_curves.png')
    plt.savefig(curves_path, dpi=300, bbox_inches='tight')
    plt.close()
    logger.info(f"📈 Training and validation curves saved: {curves_path}")


class SentencePairDataset(Dataset):
    """Sentence-pair dataset with support for weighted sampling."""

    def __init__(self, data, tokenizer, max_length=512):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.valid_data = [item for item in data if item['label'] in [0, 1]]
        logger.info(f"Raw records: {len(data)}, valid records: {len(self.valid_data)}")
        self.sentence1_list = [item['sentence1'] for item in self.valid_data]
        self.sentence2_list = [item['sentence2'] for item in self.valid_data]
        self.labels = [item['label'] for item in self.valid_data]
        self.class_counts = Counter(self.labels)
        self.class_weights = self._compute_class_weights()
        self.sample_weights = self._compute_sample_weights()

    def _compute_class_weights(self):
        """Compute per-class weights (inverse class frequency)."""
        total_samples = len(self.labels)
        class_weights = {}
        for label in [0, 1]:
            count = self.class_counts[label]
            class_weights[label] = total_samples / (2 * count)
        return class_weights

    def _compute_sample_weights(self):
        """Compute a weight for each individual sample."""
        sample_weights = []
        for label in self.labels:
            sample_weights.append(self.class_weights[label])
        return torch.tensor(sample_weights, dtype=torch.float)

    def get_weighted_sampler(self):
        """Return a weighted random sampler over this dataset."""
        return WeightedRandomSampler(
            weights=self.sample_weights,
            num_samples=len(self.sample_weights),
            replacement=True
        )

    def __len__(self):
        return len(self.valid_data)

    def __getitem__(self, idx):
        sentence1 = str(self.sentence1_list[idx])
        sentence2 = str(self.sentence2_list[idx])
        label = self.labels[idx]
        encoding = self.tokenizer(
            sentence1,
            sentence2,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }


def load_training_data(train_file):
    """Load training data from a JSON file."""
    try:
        with open(train_file, 'r', encoding='utf-8') as f:
            train_data = json.load(f)
        logger.info(f"Training data loaded: {len(train_data)} records")
        return train_data
    except Exception as e:
        logger.error(f"Failed to load training data: {str(e)}")
        return None
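
# For reference, load_training_data expects train_dataset.json to be a JSON
# array of records shaped like the illustrative example below (the sentences
# are made-up placeholders, not real training data): label 0 means "same
# paragraph", label 1 means "different paragraph".
_EXAMPLE_TRAIN_RECORD = {
    "sentence1": "第一句话。",  # first sentence of the pair
    "sentence2": "第二句话。",  # second sentence of the pair
    "label": 0                # 0 = same paragraph, 1 = different paragraph
}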

def analyze_data_distribution(data):
    """Analyze the label distribution and derive balanced Focal Loss parameters."""
    valid_data = [item for item in data if item['label'] in [0, 1]]
    label_counts = {}
    for item in valid_data:
        label = item['label']
        label_counts[label] = label_counts.get(label, 0) + 1
    total_samples = len(valid_data)
    logger.info("=== Training data distribution ===")
    logger.info(f"Total valid records: {total_samples}")
    class_proportions = {}
    for label in [0, 1]:
        count = label_counts.get(label, 0)
        proportion = count / total_samples
        class_proportions[label] = proportion
        label_name = "same paragraph" if label == 0 else "different paragraph"
        logger.info(f"Label {label} ({label_name}): {count} records ({proportion * 100:.2f}%)")

    minority_ratio = min(class_proportions.values())
    imbalance_ratio = max(class_proportions.values()) / minority_ratio
    logger.info("\n📊 Class imbalance analysis:")
    logger.info(f" 🔹 Minority-class proportion: {minority_ratio * 100:.2f}%")
    logger.info(f" 🔹 Imbalance ratio: {imbalance_ratio:.2f}:1")

    # Deliberately conservative settings to avoid over-correcting for imbalance
    if imbalance_ratio > 5:
        alpha_weights = [0.2, 0.8]  # milder weighting
        logger.info(" 🎯 Using balanced alpha weights")
    else:
        alpha_weights = [1.0 - class_proportions[0], 1.0 - class_proportions[1]]
    if imbalance_ratio > 6:
        recommended_gamma = 2.5  # lower gamma to reduce overfitting risk
        logger.info(" ⚠️ Severe imbalance - using gamma=2.5")
    elif imbalance_ratio > 4:
        recommended_gamma = 2.0
        logger.info(" ⚠️ Moderate-to-severe imbalance - using gamma=2.0")
    else:
        recommended_gamma = 1.5

    logger.info("\n🎯 Balanced Focal Loss parameters:")
    logger.info(f" 🔹 Alpha weights: [majority={alpha_weights[0]:.3f}, minority={alpha_weights[1]:.3f}]")
    logger.info(f" 🔹 Gamma: {recommended_gamma}")
    logger.info(" 🔹 Formula: FL(p_t) = -α_t * (1-p_t)^γ * log(p_t)")
    return label_counts, alpha_weights, recommended_gamma


def compute_metrics(eval_pred):
    """Compute detailed evaluation metrics, with robust handling of prediction shapes."""
    predictions, labels = eval_pred
    # 🎯 Key fix: handle irregular prediction formats
    try:
        # If predictions is a nested list or tuple, take the first element
        if isinstance(predictions, (list, tuple)):
            predictions = predictions[0]
        # Ensure predictions is a numpy array
        if not isinstance(predictions, np.ndarray):
            predictions = np.array(predictions)
        # Handle higher-dimensional arrays
        if len(predictions.shape) > 2:
            # Reshape 3-D (or higher) arrays down to 2-D
            predictions = predictions.reshape(-1, predictions.shape[-1])
        elif len(predictions.shape) == 1:
            # A 1-D array is assumed to already contain class indices
            if predictions.shape[0] != len(labels):
                logger.warning(f"Prediction shape mismatch: predictions={predictions.shape}, labels={len(labels)}")
                return {'accuracy': 0.0, 'f1_macro': 0.0, 'f1_minority': 0.0, 'f1': 0.0}
        # Take the argmax over the class dimension only for 2-D logits;
        # 1-D predictions pass through unchanged
        if predictions.ndim == 2:
            predictions = np.argmax(predictions, axis=1)
    except Exception as e:
        logger.error(f"Error while processing predictions: {e}")
        logger.error(f"predictions type: {type(predictions)}")
        if hasattr(predictions, 'shape'):
            logger.error(f"predictions shape: {predictions.shape}")
        return {'accuracy': 0.0, 'f1_macro': 0.0, 'f1_minority': 0.0, 'f1': 0.0}

    # Ensure labels is a 1-D array
    if hasattr(labels, 'flatten'):
        labels = labels.flatten()
    try:
        # Basic metrics
        accuracy = accuracy_score(labels, predictions)
        f1_macro = f1_score(labels, predictions, average='macro', zero_division=0)
        # Minority-class metrics (label 1 is assumed to be the minority class)
        f1_minority = f1_score(labels, predictions, pos_label=1, average='binary', zero_division=0)
        precision_minority = precision_score(labels, predictions, pos_label=1, average='binary', zero_division=0)
        recall_minority = recall_score(labels, predictions, pos_label=1, average='binary', zero_division=0)
        # Majority-class metrics
        f1_majority = f1_score(labels, predictions, pos_label=0, average='binary', zero_division=0)
        precision_majority = precision_score(labels, predictions, pos_label=0, average='binary', zero_division=0)
        recall_majority = recall_score(labels, predictions, pos_label=0, average='binary', zero_division=0)
        return {
            'accuracy': accuracy,
            'f1_macro': f1_macro,
            'f1_minority': f1_minority,
            'precision_minority': precision_minority,
            'recall_minority': recall_minority,
            'f1_majority': f1_majority,
            'precision_majority': precision_majority,
            'recall_majority': recall_majority,
            'f1': f1_macro  # 🎯 the primary model-selection metric is macro F1
        }
    except Exception as e:
        logger.error(f"Error while computing metrics: {e}")
        return {'accuracy': 0.0, 'f1_macro': 0.0, 'f1_minority': 0.0, 'f1': 0.0}


class FocalLoss(nn.Module):
    """Balanced Focal Loss for class-imbalanced classification.

    FL(p_t) = -α_t * (1 - p_t)^γ * log(p_t)
    """

    def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)  # probability assigned to the true class
        if self.alpha is not None:
            if self.alpha.device != inputs.device or self.alpha.dtype != inputs.dtype:
                self.alpha = self.alpha.to(device=inputs.device, dtype=inputs.dtype)
            at = self.alpha.gather(0, targets.view(-1))
            ce_loss = ce_loss * at
        focal_weight = (1 - pt) ** self.gamma
        focal_loss = focal_weight * ce_loss
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss
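
# A minimal sanity check for FocalLoss (illustrative only; never called during
# training): with gamma=0 and no alpha, the focal weight (1-p_t)^0 is 1, so the
# loss reduces to plain cross-entropy, FL(p_t) = -log(p_t).
def _focal_loss_sanity_check():
    logits = torch.tensor([[2.0, 0.5], [0.1, 1.5]])
    targets = torch.tensor([0, 1])
    fl = FocalLoss(alpha=None, gamma=0.0)(logits, targets)
    ce = F.cross_entropy(logits, targets)
    assert torch.allclose(fl, ce), "gamma=0 focal loss should equal cross-entropy"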

class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention."""

    def __init__(self, d_model, dropout=0.1):
        super(ScaledDotProductAttention, self).__init__()
        self.d_model = d_model
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        batch_size, seq_len, d_model = query.size()
        # Attention(Q, K, V) = softmax(QK^T / sqrt(d_model)) V
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_model)
        if mask is not None:
            mask_value = torch.finfo(scores.dtype).min
            scores = scores.masked_fill(mask == 0, mask_value)
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        output = torch.matmul(attention_weights, value)
        return output, attention_weights


class RoBERTaWithScaledAttentionAndFocalLoss(nn.Module):
    """RoBERTa with an extra scaled dot-product attention layer and balanced Focal Loss."""

    def __init__(self, model_path, num_labels=2, dropout=0.1, focal_alpha=None, focal_gamma=2.0):
        super(RoBERTaWithScaledAttentionAndFocalLoss, self).__init__()
        self.roberta = BertModel.from_pretrained(model_path)
        self.config = self.roberta.config
        self.config.num_labels = num_labels
        self.scaled_attention = ScaledDotProductAttention(
            d_model=self.config.hidden_size,
            dropout=dropout
        )
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(self.config.hidden_size, num_labels)
        self.focal_loss = FocalLoss(alpha=focal_alpha, gamma=focal_gamma)
        self._init_weights()
        self.focal_alpha = focal_alpha
        self.focal_gamma = focal_gamma

    def _init_weights(self):
        """Initialize the weights of the newly added layers."""
        nn.init.normal_(self.classifier.weight, std=0.02)
        nn.init.zeros_(self.classifier.bias)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, labels=None):
        roberta_outputs = self.roberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=True
        )
        sequence_output = roberta_outputs.last_hidden_state
        enhanced_output, attention_weights = self.scaled_attention(
            query=sequence_output,
            key=sequence_output,
            value=sequence_output,
            mask=attention_mask.unsqueeze(1) if attention_mask is not None else None
        )
        cls_output = enhanced_output[:, 0, :]
        cls_output = self.dropout(cls_output)
        logits = self.classifier(cls_output)
        loss = None
        if labels is not None:
            loss = self.focal_loss(logits, labels)
        return {
            'loss': loss,
            'logits': logits,
            'hidden_states': enhanced_output,
            'attention_weights': attention_weights
        }

    def save_pretrained(self, save_directory):
        """Save the model weights and a lightweight config."""
        os.makedirs(save_directory, exist_ok=True)
        model_to_save = self.module if hasattr(self, 'module') else self
        torch.save(model_to_save.state_dict(), os.path.join(save_directory, 'pytorch_model.bin'))
        config_dict = {
            'model_type': 'RoBERTaWithScaledAttentionAndFocalLoss',
            'base_model': 'chinese-roberta-wwm-ext',
            'num_labels': self.config.num_labels,
            'hidden_size': self.config.hidden_size,
            'focal_alpha': self.focal_alpha.tolist() if self.focal_alpha is not None else None,
            'focal_gamma': self.focal_gamma,
            'has_scaled_attention': True,
            'has_focal_loss': True,
            'optimization_level': 'scientific_training',
            'model_selection': 'macro_f1_based'
        }
        with open(os.path.join(save_directory, 'config.json'), 'w', encoding='utf-8') as f:
            json.dump(config_dict, f, ensure_ascii=False, indent=2)
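
# Reloading a checkpoint written by save_pretrained above. A minimal sketch
# under the assumption that the original base-model directory is still
# available (the custom head cannot be loaded via transformers'
# from_pretrained); `_load_finetuned_model` is a hypothetical helper, not used
# by the training pipeline.
def _load_finetuned_model(save_directory, base_model_path, device="cuda"):
    with open(os.path.join(save_directory, 'config.json'), 'r', encoding='utf-8') as f:
        cfg = json.load(f)
    alpha = torch.tensor(cfg['focal_alpha']) if cfg.get('focal_alpha') else None
    model = RoBERTaWithScaledAttentionAndFocalLoss(
        model_path=base_model_path,
        num_labels=cfg['num_labels'],
        focal_alpha=alpha,
        focal_gamma=cfg.get('focal_gamma', 2.0),
    )
    state_dict = torch.load(os.path.join(save_directory, 'pytorch_model.bin'),
                            map_location=device)
    model.load_state_dict(state_dict)
    return model.to(device).eval()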

class WeightedTrainer(Trainer):
    """Custom Trainer that supports a WeightedRandomSampler for the training set."""

    def __init__(self, weighted_sampler=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.weighted_sampler = weighted_sampler

    def get_train_dataloader(self):
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")
        train_dataset = self.train_dataset
        if self.weighted_sampler is not None:
            train_sampler = self.weighted_sampler
        else:
            train_sampler = self._get_train_sampler()
        return DataLoader(
            train_dataset,
            batch_size=self.args.train_batch_size,
            sampler=train_sampler,
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
        )
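
# Illustration of what the weighted sampling used by WeightedTrainer does
# (a standalone toy, never called during training): with weights inversely
# proportional to class frequency, a 90/10 label distribution is resampled
# to roughly 50/50.
def _demo_weighted_sampler():
    labels = [0] * 90 + [1] * 10
    weights = torch.tensor([1.0 / 90 if l == 0 else 1.0 / 10 for l in labels])
    sampler = WeightedRandomSampler(weights, num_samples=1000, replacement=True)
    sampled = Counter(labels[i] for i in sampler)
    print(sampled)  # expected: roughly {0: ~500, 1: ~500}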

def train_roberta_model(train_data,
                        model_path="/root/autodl-tmp/chinese-roberta-wwm-ext/chinese-roberta-wwm-ext/model",
                        output_dir="/root/autodl-tmp/chinese-roberta-wwm-ext/chinese-roberta-wwm-ext/model_train",
                        checkpoint_dir="/root/autodl-tmp/chinese-roberta-wwm-ext/chinese-roberta-wwm-ext/ouput_result"):
    """Scientifically train the RoBERTa model (macro-F1-based model selection)."""
    gpu_available, gpu_memory = check_gpu_availability()
    device = torch.device('cuda')
    logger.info(f"🚀 Using GPU device: {device}")

    # Analyze the data distribution and derive balanced Focal Loss parameters
    label_distribution, alpha_weights, recommended_gamma = analyze_data_distribution(train_data)
    alpha_tensor = torch.tensor(alpha_weights, dtype=torch.float).to(device)

    logger.info(f"📥 Loading Chinese-RoBERTa-WWM-Ext model: {model_path}")
    tokenizer = BertTokenizer.from_pretrained(model_path)
    model = RoBERTaWithScaledAttentionAndFocalLoss(
        model_path=model_path,
        num_labels=2,
        dropout=0.1,
        focal_alpha=alpha_tensor,
        focal_gamma=recommended_gamma
    )
    model = model.to(device)
    logger.info(f"✅ Model loaded onto GPU: {next(model.parameters()).device}")
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info("📊 Model parameter statistics:")
    logger.info(f" 🔹 Total parameters: {total_params:,}")
    logger.info(f" 🔹 Trainable parameters: {trainable_params:,}")

    # GPU memory configuration, chosen before building the dataset so that
    # the selected sequence length is actually used for tokenization
    if gpu_memory >= 45:  # 48 GB
        batch_size = 16
        gradient_accumulation = 2
        max_length = 512
        dataloader_num_workers = 4
    elif gpu_memory >= 30:  # 32 GB
        batch_size = 12
        gradient_accumulation = 3
        max_length = 448
        dataloader_num_workers = 3
    elif gpu_memory >= 22:  # 24 GB
        batch_size = 8
        gradient_accumulation = 4
        max_length = 384
        dataloader_num_workers = 2
    else:  # 8-16 GB
        batch_size = 4
        gradient_accumulation = 8
        max_length = 256
        dataloader_num_workers = 1
    effective_batch_size = batch_size * gradient_accumulation

    # Prepare the full dataset
    logger.info("📋 Preparing the dataset...")
    full_dataset = SentencePairDataset(train_data, tokenizer, max_length=max_length)

    # 🎯 Key improvement: split into training and validation sets (80:20)
    total_size = len(full_dataset)
    train_size = int(0.8 * total_size)
    val_size = total_size - train_size
    # Fix the random seed for reproducibility
    torch.manual_seed(42)
    train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])
    logger.info(f" 🔹 Training set size: {train_size}")
    logger.info(f" 🔹 Validation set size: {val_size}")
    logger.info(f" 🔹 Train/validation ratio: {train_size / total_size:.1%}/{val_size / total_size:.1%}")

    # Build a weighted sampler for the training set
    # Note: the sample weights come from the underlying full dataset
    train_indices = train_dataset.indices
    train_weights = full_dataset.sample_weights[train_indices]
    train_sampler = WeightedRandomSampler(
        weights=train_weights,
        num_samples=len(train_weights),
        replacement=True
    )

    # Conservative learning-rate strategy
    initial_learning_rate = 1.5e-5
    warmup_ratio = 0.1  # 10% warmup

    # Make sure the output directories exist
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)

    # 🎯 Key improvement: principled training configuration
    training_args = TrainingArguments(
        output_dir=checkpoint_dir,
        # Training schedule
        num_train_epochs=100,  # capped at 100 epochs, combined with early stopping
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        gradient_accumulation_steps=gradient_accumulation,
        # 🎯 Validation and evaluation strategy
        eval_strategy="epoch",  # evaluate every epoch
        save_strategy="epoch",  # save every epoch
        logging_strategy="steps",
        logging_steps=50,
        # 🎯 Key: macro-F1-based model selection
        load_best_model_at_end=True,  # load the best model at the end
        metric_for_best_model="eval_f1",  # macro F1 drives best-model selection
        greater_is_better=True,  # higher F1 is better
        # Learning-rate configuration
        learning_rate=initial_learning_rate,
        warmup_ratio=warmup_ratio,
        lr_scheduler_type="cosine",
        # Regularization
        weight_decay=0.01,
        # Hardware optimization
        fp16=True,
        dataloader_pin_memory=True,
        dataloader_num_workers=dataloader_num_workers,
        group_by_length=True,
        # Checkpointing
        save_total_limit=3,  # keep only the 3 most recent checkpoints
        # Miscellaneous
        remove_unused_columns=False,
        report_to=[],
        adam_epsilon=1e-8,
        max_grad_norm=1.0,
        skip_memory_metrics=True,
        disable_tqdm=False,
        # Seeds
        seed=42,
        data_seed=42,
    )

    logger.info("🎯 Scientific-training configuration:")
    logger.info(f" 🔹 Max training epochs: {training_args.num_train_epochs}")
    logger.info(f" 🔹 Batch size: {batch_size}")
    logger.info(f" 🔹 Effective batch size: {effective_batch_size}")
    logger.info(f" 🔹 Learning rate: {training_args.learning_rate}")
    logger.info(f" 🔹 Warmup ratio: {warmup_ratio}")
    logger.info(f" 🔹 Sequence length: {max_length}")
    logger.info(" 🔹 Validation strategy: evaluate every epoch")
    logger.info(" 🔹 Model selection: macro F1 score")
    logger.info(" 🔹 Early stopping: stop after 8 epochs without improvement")

    # Data collator
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    # 🎯 Key: initialize the callbacks
    loss_tracker = LossTracker()
    validation_metrics_callback = ValidationMetricsCallback()
    confusion_matrix_callback = ConfusionMatrixCallback(
        eval_dataset=val_dataset,  # use the validation set
        tokenizer=tokenizer,
        output_dir=checkpoint_dir,
        epochs_interval=10  # generate one every 10 epochs
    )
    # 🎯 Key: early stopping after 8 epochs without improvement
    early_stopping_callback = EarlyStoppingCallback(
        early_stopping_patience=8,  # stop after 8 epochs without improvement
        early_stopping_threshold=0.001  # minimum improvement threshold
    )

    # 🎯 Use the WeightedTrainer
    trainer = WeightedTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,  # 80% training split
        eval_dataset=val_dataset,  # 20% validation split
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,  # detailed metric computation
        callbacks=[
            loss_tracker,
            validation_metrics_callback,
            confusion_matrix_callback,
            early_stopping_callback
        ],
        weighted_sampler=train_sampler  # weighted sampler
    )

    logger.info("🏃‍♂️ Starting scientific training...")
    logger.info("🎯 Scientific-training features:")
    logger.info(" ✅ Train/validation split (80:20)")
    logger.info(" ✅ Macro-F1-based model selection")
    logger.info(" ✅ Early stopping (patience of 8 epochs)")
    logger.info(" ✅ Per-epoch validation and checkpointing")
    logger.info(" ✅ Detailed validation metric monitoring")
    logger.info(" ✅ Balanced Focal Loss parameters")
    logger.info(" ✅ WeightedRandomSampler")
    logger.info(" ✅ Cosine-annealing learning-rate schedule")
    logger.info(" ✅ Automatic best-model selection")

    start_time = datetime.now()
    try:
        # 🎯 Run the training
        trainer.train()
        logger.info("📊 Training finished, analyzing the best model...")
        # Retrieve information about the best model
        validation_history = validation_metrics_callback.validation_history
        if validation_history:
            best_metrics = max(validation_history, key=lambda x: x['eval_f1_macro'])
            best_epoch = best_metrics['epoch']
            best_f1 = best_metrics['eval_f1_macro']
            logger.info("🏆 Best model:")
            logger.info(f" 🔹 Best epoch: {best_epoch}")
            logger.info(f" 🔹 Best macro F1: {best_f1:.4f}")
            logger.info(f" 🔹 Validation accuracy: {best_metrics['eval_accuracy']:.4f}")
            logger.info(f" 🔹 Minority-class F1: {best_metrics['eval_f1_minority']:.4f}")
            logger.info(f" 🔹 Minority-class precision: {best_metrics['eval_precision_minority']:.4f}")
            logger.info(f" 🔹 Minority-class recall: {best_metrics['eval_recall_minority']:.4f}")
    except RuntimeError as e:
        if "out of memory" in str(e).lower():
            logger.error("❌ GPU out of memory!")
            logger.error("💡 Consider reducing the batch size")
            raise
        else:
            raise
    end_time = datetime.now()
    training_duration = end_time - start_time
    # Number of epochs actually trained
    actual_epochs = len(validation_metrics_callback.validation_history)
    logger.info(f"🎉 Scientific training finished! Duration: {training_duration}")
    logger.info(f"📊 Epochs actually trained: {actual_epochs} (maximum {training_args.num_train_epochs})")

    # Generate the training visualizations
    logger.info("📈 Generating training visualizations...")
    plot_training_curves(loss_tracker, validation_metrics_callback.validation_history, checkpoint_dir)

    # 🎯 Save the best model to the designated directory
    logger.info(f"💾 Saving the best model to: {output_dir}")
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)

    # Save the detailed training history
    training_history = {
        'train_losses': loss_tracker.train_losses,
        'train_steps': loss_tracker.train_steps,
        'eval_losses': loss_tracker.eval_losses,
        'eval_steps': loss_tracker.eval_steps,
        'validation_metrics': validation_metrics_callback.validation_history
    }
    with open(os.path.join(checkpoint_dir, 'training_history.json'), 'w', encoding='utf-8') as f:
        json.dump(training_history, f, ensure_ascii=False, indent=2)

    # Save the confusion-matrix history
    cm_history = {epoch: cm.tolist() for epoch, cm in confusion_matrix_callback.confusion_matrices.items()}
    with open(os.path.join(checkpoint_dir, 'confusion_matrix_history.json'), 'w', encoding='utf-8') as f:
        json.dump(cm_history, f, ensure_ascii=False, indent=2)

    # Determine the best validation entry (by macro F1) for reporting; using the
    # last entry here would misreport the "best" model under early stopping
    validation_history = validation_metrics_callback.validation_history
    best_entry = max(validation_history, key=lambda x: x['eval_f1_macro']) if validation_history else None

    # Save the detailed information about this scientific-training run
    training_info = {
        "model_name": model_path,
        "model_type": "Chinese-RoBERTa-WWM-Ext with Scientific Training",
        "training_methodology": "macro_f1_based_selection",
        "training_duration": str(training_duration),
        "actual_epochs_trained": actual_epochs,
        "max_epochs_configured": training_args.num_train_epochs,
        # Data split
        "data_split": {
            "total_samples": total_size,
            "train_samples": train_size,
            "validation_samples": val_size,
            "train_ratio": train_size / total_size,
            "validation_ratio": val_size / total_size
        },
        # Label distribution
        "label_distribution": label_distribution,
        "data_imbalance": {
            "class_0_count": label_distribution.get(0, 0),
            "class_1_count": label_distribution.get(1, 0),
            "class_0_ratio": label_distribution.get(0, 0) / total_size,
            "class_1_ratio": label_distribution.get(1, 0) / total_size,
            "imbalance_ratio": label_distribution.get(0, 1) / label_distribution.get(1, 1)
        },
        # Balanced Focal Loss parameters
        "balanced_focal_loss_params": {
            "alpha_weights": alpha_weights,
            "gamma": recommended_gamma,
            "formula": "FL(p_t) = -α_t * (1-p_t)^γ * log(p_t)",
            "approach": "balanced_macro_f1_focus"
        },
        # Sampling strategy
        "weighted_sampling": {
            "enabled": True,
            "strategy": "WeightedRandomSampler",
            "applied_to": "training_set_only"
        },
        # Scientific-training strategy
        "scientific_training_strategy": {
            "model_selection_metric": "eval_f1_macro",
            "early_stopping_patience": 8,
            "early_stopping_threshold": 0.001,
            "validation_frequency": "every_epoch",
            "best_model_loading": True
        },
        # Learning-rate strategy
        "learning_strategy": {
            "initial_learning_rate": initial_learning_rate,
            "warmup_ratio": warmup_ratio,
            "lr_scheduler": "cosine",
            "approach": "conservative_and_stable"
        },
        # Hardware optimization
        "gpu_optimization": {
            "gpu_name": torch.cuda.get_device_name(0),
            "gpu_memory_gb": gpu_memory,
            "effective_batch_size": effective_batch_size,
            "sequence_length": max_length,
            "optimization_level": "scientific_training"
        },
        # Training arguments
        "training_args": {
            "num_train_epochs": training_args.num_train_epochs,
            "per_device_train_batch_size": training_args.per_device_train_batch_size,
            "gradient_accumulation_steps": training_args.gradient_accumulation_steps,
            "learning_rate": training_args.learning_rate,
            "warmup_ratio": training_args.warmup_ratio,
            "weight_decay": training_args.weight_decay,
            "fp16": training_args.fp16,
            "lr_scheduler_type": training_args.lr_scheduler_type
        },
        # Model parameters
        "model_parameters": {
            "total_params": total_params,
            "trainable_params": trainable_params,
        },
        # Paths
        "paths": {
            "model_input_path": model_path,
            "model_output_path": output_dir,
            "checkpoint_output_path": checkpoint_dir
        },
        # Scientific-training improvements
        "scientific_improvements": [
            "Train/Validation split (80:20) for unbiased evaluation",
            "Macro F1 score as primary model selection metric",
            "Early stopping with 8-epoch patience to prevent overfitting",
            "Balanced Focal Loss parameters to avoid over-optimization",
            "Every-epoch validation for detailed monitoring",
            "Automatic best model selection and loading",
            "Conservative learning rate for stable convergence",
            "Comprehensive validation metrics tracking",
            "WeightedRandomSampler for balanced training",
            "Cosine annealing learning rate scheduler"
        ],
        # Best model (by macro F1, not simply the last epoch)
        "best_model_info": best_entry,
        # Visualization files
        "visualization_files": {
            "training_validation_curves": "training_validation_curves.png",
            "validation_confusion_matrices": [f"val_confusion_matrix_epoch_{i}.png"
                                              for i in range(10, actual_epochs + 1, 10)],
            "training_history": "training_history.json",
            "confusion_matrix_history": "confusion_matrix_history.json"
        },
        "training_completed": end_time.isoformat()
    }
    with open(os.path.join(checkpoint_dir, 'scientific_training_info.json'), 'w', encoding='utf-8') as f:
        json.dump(training_info, f, ensure_ascii=False, indent=2)

    # Also save a summary in the model directory
    model_summary = {
        "model_selection_method": "macro_f1_based",
        "best_epoch": best_entry['epoch'] if best_entry else actual_epochs,
        "best_macro_f1": best_entry['eval_f1_macro'] if best_entry else None,
        "training_methodology": "scientific_with_early_stopping",
        "data_split": "80_20_train_validation"
    }
    with open(os.path.join(output_dir, 'model_selection_summary.json'), 'w', encoding='utf-8') as f:
        json.dump(model_summary, f, ensure_ascii=False, indent=2)
    logger.info("📋 Scientific-training information saved")
    return trainer, model, tokenizer, loss_tracker, validation_metrics_callback
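
# The script recommends evaluating the selected model on a held-out test set.
# A minimal sketch of such an evaluation (`_evaluate_on_test_set` and the file
# name `test_dataset.json` are hypothetical; the file is assumed to use the
# same record format as the training data):
def _evaluate_on_test_set(model, tokenizer, test_file, max_length=512, batch_size=32):
    with open(test_file, 'r', encoding='utf-8') as f:
        test_data = json.load(f)
    dataset = SentencePairDataset(test_data, tokenizer, max_length=max_length)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    device = next(model.parameters()).device
    preds, labels = [], []
    model.eval()
    with torch.no_grad():
        for batch in loader:
            outputs = model(input_ids=batch['input_ids'].to(device),
                            attention_mask=batch['attention_mask'].to(device))
            preds.extend(torch.argmax(outputs['logits'], dim=-1).cpu().tolist())
            labels.extend(batch['labels'].tolist())
    logger.info("\n" + classification_report(labels, preds, digits=4))
    return f1_score(labels, preds, average='macro')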

def main():
    """Main entry point."""
    logger.info("=" * 120)
    logger.info("🚀 Chinese-RoBERTa-WWM-Ext scientific training (macro-F1-based model selection)")
    logger.info("=" * 120)

    # Configure paths
    train_file = "/root/autodl-tmp/chinese-roberta-wwm-ext/chinese-roberta-wwm-ext/Data/train_dataset.json"
    model_path = "/root/autodl-tmp/chinese-roberta-wwm-ext/chinese-roberta-wwm-ext/model"
    output_dir = "/root/autodl-tmp/chinese-roberta-wwm-ext/chinese-roberta-wwm-ext/model_train"
    checkpoint_dir = "/root/autodl-tmp/chinese-roberta-wwm-ext/chinese-roberta-wwm-ext/ouput_result"

    # Make sure all output directories exist
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)
    logger.info("📁 Ensured output directories exist:")
    logger.info(f" 🔹 Best-model output: {output_dir}")
    logger.info(f" 🔹 Training records: {checkpoint_dir}")

    # Confirm that third-party reporting tools are disabled
    logger.info("🚫 Third-party reporting tool status:")
    logger.info(f" 🔹 WANDB_DISABLED: {os.environ.get('WANDB_DISABLED', 'not set')}")
    logger.info(f" 🔹 COMET_MODE: {os.environ.get('COMET_MODE', 'not set')}")
    logger.info(f" 🔹 NEPTUNE_MODE: {os.environ.get('NEPTUNE_MODE', 'not set')}")
    logger.info(" ✅ All third-party reporting tools disabled")

    logger.info("\n📋 Scientific-training configuration:")
    logger.info(f" 🔹 Training data: {train_file}")
    logger.info(f" 🔹 Base model: {model_path}")
    logger.info(" 🔹 Model type: Chinese-RoBERTa-WWM-Ext")
    logger.info(" 🔹 Training method: Scientific Training")
    logger.info(" 🔹 Core improvements:")
    logger.info("   • Train/validation split (80:20)")
    logger.info("   • Macro-F1-based model selection")
    logger.info("   • Early stopping (patience of 8 epochs)")
    logger.info("   • Per-epoch validation and monitoring")
    logger.info("   • Balanced Focal Loss parameters")
    logger.info("   • Automatic best-model selection")
    logger.info(" 🔹 Max training epochs: 100 (early stopping may end sooner)")
    logger.info(f" 🔹 Best-model output: {output_dir}")
    logger.info(f" 🔹 Training records: {checkpoint_dir}")
    # Load the training data
    train_data = load_training_data(train_file)
    if train_data is None:
        logger.error("❌ Could not load training data, exiting")
        return

    try:
        # Run the scientific training
        trainer, model, tokenizer, loss_tracker, validation_callback = train_roberta_model(
            train_data,
            model_path=model_path,
            output_dir=output_dir,
            checkpoint_dir=checkpoint_dir
        )
        logger.info("=" * 120)
        logger.info("🎉 Scientific training finished!")
        logger.info("=" * 120)
        logger.info("📁 Output locations:")
        logger.info(f" 🔹 Best model: {output_dir}")
        logger.info(f" 🔹 Training records and plots: {checkpoint_dir}")
        logger.info("📄 Generated files:")
        logger.info(" Best-model files (model_train directory):")
        logger.info("  • pytorch_model.bin - best model selected by macro F1")
        logger.info("  • config.json - scientific-training model config")
        logger.info("  • tokenizer configuration files")
        logger.info("  • model_selection_summary.json - model-selection summary")
        logger.info(" Scientific-training records (ouput_result directory):")
        logger.info("  • scientific_training_info.json - full scientific-training info")
        logger.info("  • training_history.json - training and validation history")
        logger.info("  • confusion_matrix_history.json - confusion-matrix history")
        logger.info("  • training_validation_curves.png - training/validation curves")
        logger.info("  • val_confusion_matrix_epoch_X.png - validation confusion matrices")
        logger.info("  • checkpoint-* - training checkpoints")
        logger.info("🔥 Scientific-training features:")
        logger.info(" ✅ Chinese-RoBERTa-WWM-Ext base model")
        logger.info(" ✅ Data-science practice: 80:20 train/validation split")
        logger.info(" ✅ Informed model selection: macro F1 score")
        logger.info(" ✅ Overfitting protection: 8-epoch early stopping")
        logger.info(" ✅ Balanced optimization: mild Focal Loss parameters")
        logger.info(" ✅ Full monitoring: per-epoch validation")
        logger.info(" ✅ Automated selection: best model saved automatically")
        logger.info(" ✅ WeightedRandomSampler balanced sampling")
        logger.info(" ✅ Cosine-annealing learning-rate schedule")
        logger.info(" ✅ Complete visualization and metric tracking")
        logger.info("🎯 Advantages of the scientific approach:")
        logger.info(" ⚡ Unbiased validation safeguards generalization")
        logger.info(" ⚡ Model selection driven by the target metric")
        logger.info(" ⚡ Early stopping prevents overfitting")
        logger.info(" ⚡ Balanced parameters avoid over-optimization")
        logger.info(" ⚡ End-to-end monitoring safeguards training quality")

        # Report the selected best model
        if validation_callback.validation_history:
            best_metrics = max(validation_callback.validation_history, key=lambda x: x['eval_f1_macro'])
            logger.info("\n🏆 Final selected best model:")
            logger.info(f" 🔹 Source epoch: {best_metrics['epoch']}")
            logger.info(f" 🔹 Macro F1: {best_metrics['eval_f1_macro']:.4f}")
            logger.info(f" 🔹 Validation accuracy: {best_metrics['eval_accuracy']:.4f}")
            logger.info(f" 🔹 Minority-class F1: {best_metrics['eval_f1_minority']:.4f}")
            logger.info(f" 🔹 Minority-class precision: {best_metrics['eval_precision_minority']:.4f}")
            logger.info(f" 🔹 Minority-class recall: {best_metrics['eval_recall_minority']:.4f}")

        # Report saved files
        logger.info("\n📂 Saved files:")
        logger.info(f"📋 Best model ({output_dir}):")
        try:
            for file in os.listdir(output_dir):
                file_path = os.path.join(output_dir, file)
                if os.path.isfile(file_path):
                    file_size = os.path.getsize(file_path) / (1024 * 1024)
                    logger.info(f" 📄 {file} ({file_size:.2f} MB)")
        except Exception as e:
            logger.warning(f" ⚠️ Could not list model files: {str(e)}")

        logger.info(f"📋 Training records ({checkpoint_dir}):")
        try:
            files = os.listdir(checkpoint_dir)
            json_files = [f for f in files if f.endswith('.json')]
            png_files = [f for f in files if f.endswith('.png')]
            checkpoint_dirs = [f for f in files if f.startswith('checkpoint-')]
            if json_files:
                logger.info(" Configuration and history files:")
                for file in sorted(json_files):
                    file_path = os.path.join(checkpoint_dir, file)
                    file_size = os.path.getsize(file_path) / 1024
                    logger.info(f" 📄 {file} ({file_size:.1f} KB)")
            if png_files:
                logger.info(" Visualization plots:")
                for file in sorted(png_files):
                    file_path = os.path.join(checkpoint_dir, file)
                    file_size = os.path.getsize(file_path) / 1024
                    logger.info(f" 📊 {file} ({file_size:.1f} KB)")
            if checkpoint_dirs:
                logger.info(" Model checkpoints:")
                for dir_name in sorted(checkpoint_dirs):
                    logger.info(f" 📁 {dir_name}/")
        except Exception as e:
            logger.warning(f" ⚠️ Could not list training records: {str(e)}")

        logger.info("\n🎯 Scientific training complete; the best model was selected and saved automatically!")
        logger.info("💡 Tip: evaluate the selected model on a held-out test set to verify generalization")
    except Exception as e:
        logger.error(f"❌ Error during scientific training: {str(e)}")
        import traceback
        traceback.print_exc()
        raise


if __name__ == "__main__":
    main()
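
# Assumed invocation (the file name is hypothetical; all paths are hard-coded
# above and should be adjusted for your environment before running):
#
#   python train_roberta.py
#
# Note: TrainingArguments is configured with `eval_strategy`, which assumes a
# recent transformers release (older releases name it `evaluation_strategy`).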