import torch import torch.nn as nn import torch.nn.functional as F import pandas as pd import numpy as np import matplotlib.pyplot as plt import seaborn as sns from sklearn.metrics import accuracy_score, classification_report, confusion_matrix from sklearn.model_selection import train_test_split from transformers import ( BertTokenizer, BertForSequenceClassification, BertModel, BertConfig, TrainingArguments, Trainer, DataCollatorWithPadding, TrainerCallback, EarlyStoppingCallback ) from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler import logging import os from datetime import datetime import math import json from collections import defaultdict, Counter import copy # 禁用wandb和其他第三方报告工具 os.environ["WANDB_DISABLED"] = "true" os.environ["COMET_MODE"] = "disabled" os.environ["NEPTUNE_MODE"] = "disabled" # 设置日志 logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # 设置matplotlib中文字体 plt.rcParams['font.sans-serif'] = ['DejaVu Sans'] plt.rcParams['axes.unicode_minus'] = False def check_gpu_availability(): """检查GPU可用性 - V100 48GB优化版""" if not torch.cuda.is_available(): raise RuntimeError("❌ GPU不可用!此脚本需要GPU支持。") gpu_count = torch.cuda.device_count() gpu_name = torch.cuda.get_device_name(0) gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024 ** 3 logger.info(f"✅ GPU检查通过!") logger.info(f" 🔹 可用GPU数量: {gpu_count}") logger.info(f" 🔹 GPU型号: {gpu_name}") logger.info(f" 🔹 GPU内存: {gpu_memory:.1f} GB") # V100优化设置 + 对抗训练优化 torch.cuda.empty_cache() torch.backends.cudnn.benchmark = True torch.backends.cudnn.deterministic = False # 对抗训练时允许非确定性以提升性能 # 对抗训练内存优化 if gpu_memory >= 40: # V100 48GB logger.info("🚀 检测到V100大显存GPU,启用对抗训练优化配置") return True, gpu_memory, "v100_48gb" else: logger.info("⚠️ 检测到较小显存GPU,将使用保守的对抗训练配置") return True, gpu_memory, "standard" class AdversarialConfig: """对抗训练配置类 - 温和的扰动强度设置""" def __init__(self, imbalance_ratio=6.0, gpu_type="v100_48gb"): self.imbalance_ratio = imbalance_ratio self.gpu_type = gpu_type # 使用更温和的扰动强度 self.majority_class_eps = 0.05 # 降低多数类扰动 self.minority_class_eps = 0.1 # 降低少数类扰动,仍保持2倍差异 self.eps_ratio = 2.0 # 使用原始的Focal Loss参数,不做修改 self.focal_alpha = [0.1, 0.9] # 保持原始权重 self.focal_gamma = 3.0 # 保持原始gamma值 # 更温和的FreeLB参数 self.adv_steps = 3 # 减少对抗步数 self.adv_lr = 0.01 # 降低对抗学习率 self.norm_type = "l2" self.max_norm = 0.5 # 降低最大范数约束 # 温和的训练阶段参数 self.phase_configs = { 'adaptation': { # 适应期 (1-10 epochs) - 延长适应期 'eps': {'majority': 0.03, 'minority': 0.06}, # 非常温和的开始 'clean_ratio': 0.9, # 更多干净样本 'learning_rate': 2e-5, # 使用原始学习率 'adv_steps': 2 # 最少的对抗步数 }, 'enhancement': { # 强化期 (11-25 epochs) 'eps': {'majority': 0.05, 'minority': 0.1}, # 温和提升 'clean_ratio': 0.7, # 仍然更多干净样本 'learning_rate': 1.5e-5, # 温和降低 'adv_steps': 3 }, 'refinement': { # 精调期 (26+ epochs) 'eps': {'majority': 0.04, 'minority': 0.08}, # 回调扰动 'clean_ratio': 0.8, # 精调时更多干净样本 'learning_rate': 1e-5, # 精调阶段 'adv_steps': 2 } } # 更温和的类别差异化采样 self.sampling_configs = { 'adaptation': {'majority': 0.8, 'minority': 0.2}, # 接近原始分布 'enhancement': {'majority': 0.7, 'minority': 0.3}, # 温和调整 'refinement': {'majority': 0.6, 'minority': 0.4} # 最终稍微平衡 } logger.info(f"🎯 温和对抗训练配置初始化完成:") logger.info(f" 🔹 数据不平衡比例: {imbalance_ratio}:1") logger.info(f" 🔹 Focal Loss权重: {self.focal_alpha} (原始设置)") logger.info(f" 🔹 Focal Loss gamma: {self.focal_gamma} (原始设置)") logger.info(f" 🔹 温和扰动: 多数类{self.majority_class_eps} vs 少数类{self.minority_class_eps}") logger.info(f" 🔹 对抗步数: {self.adv_steps} (温和设置)") logger.info(f" 🔹 对抗学习率: {self.adv_lr} (温和设置)") logger.info(f" 🔹 GPU配置: {gpu_type}") class AdversarialPerturbationGenerator: """FreeLB对抗扰动生成器""" def __init__(self, config: AdversarialConfig): self.config = config def generate_perturbation(self, model, inputs, labels, current_phase='enhancement'): """生成对抗扰动""" device = next(model.parameters()).device # 获取当前阶段配置 phase_config = self.config.phase_configs[current_phase] # 获取输入嵌入 input_ids = inputs['input_ids'].to(device) attention_mask = inputs['attention_mask'].to(device) # 获取嵌入层 embedding_layer = model.roberta.embeddings.word_embeddings embeds = embedding_layer(input_ids) # 根据标签确定扰动强度 batch_eps = [] for label in labels: if label.item() == 0: # 多数类 batch_eps.append(phase_config['eps']['majority']) else: # 少数类 batch_eps.append(phase_config['eps']['minority']) batch_eps = torch.tensor(batch_eps, device=device).unsqueeze(-1).unsqueeze(-1) # 初始化扰动 delta = torch.zeros_like(embeds).uniform_(-1, 1) delta = delta * batch_eps * 0.1 # 初始化为小扰动 delta.requires_grad_(True) # 多步对抗扰动生成 for step in range(phase_config['adv_steps']): # 前向传播 perturbed_embeds = embeds + delta # 替换嵌入层输出 outputs = model( inputs_embeds=perturbed_embeds, attention_mask=attention_mask, labels=labels ) loss = outputs['loss'] # 计算梯度 grad = torch.autograd.grad( loss, delta, retain_graph=True if step < phase_config['adv_steps'] - 1 else False )[0] # 归一化梯度 if self.config.norm_type == "l2": grad_norm = torch.norm(grad, dim=-1, keepdim=True) grad = grad / (grad_norm + 1e-8) else: # linf grad = grad.sign() # 更新扰动 delta = delta + self.config.adv_lr * grad # 投影到约束球 if self.config.norm_type == "l2": delta_norm = torch.norm(delta, dim=-1, keepdim=True) delta = delta / torch.max(delta_norm / batch_eps, torch.ones_like(delta_norm)) else: delta = torch.clamp(delta, -batch_eps, batch_eps) delta = delta.detach() delta.requires_grad_(True) return embeds + delta.detach() class LossTracker(TrainerCallback): """增强的损失跟踪回调类 - 对抗训练版""" def __init__(self): self.train_losses = [] self.eval_losses = [] self.eval_accuracies = [] self.eval_minority_f1 = [] self.train_steps = [] self.eval_steps = [] self.epochs = [] self.current_epoch = 0 self.best_eval_accuracy = 0.0 self.best_minority_f1 = 0.0 self.best_epoch = 0 # 对抗训练特有指标 self.adversarial_losses = [] self.clean_losses = [] self.perturbation_norms = [] def on_log(self, args, state, control, logs=None, **kwargs): if logs: if 'loss' in logs: self.train_losses.append(logs['loss']) self.train_steps.append(state.global_step) if 'eval_loss' in logs: self.eval_losses.append(logs['eval_loss']) self.eval_steps.append(state.global_step) if 'eval_accuracy' in logs: self.eval_accuracies.append(logs['eval_accuracy']) if 'eval_minority_f1' in logs: self.eval_minority_f1.append(logs['eval_minority_f1']) # 更新最佳少数类F1 if logs['eval_minority_f1'] > self.best_minority_f1: self.best_minority_f1 = logs['eval_minority_f1'] self.best_epoch = self.current_epoch if 'eval_accuracy' in logs: if logs['eval_accuracy'] > self.best_eval_accuracy: self.best_eval_accuracy = logs['eval_accuracy'] def on_epoch_end(self, args, state, control, **kwargs): self.current_epoch = state.epoch self.epochs.append(state.epoch) class AdversarialValidationCallback(TrainerCallback): """对抗训练验证回调 - 每5个epoch生成详细分析""" def __init__(self, eval_dataset, tokenizer, output_dir, adv_config, epochs_interval=5): self.eval_dataset = eval_dataset self.tokenizer = tokenizer self.output_dir = output_dir self.adv_config = adv_config self.epochs_interval = epochs_interval self.confusion_matrices = {} self.detailed_metrics = {} def on_epoch_end(self, args, state, control, model=None, **kwargs): current_epoch = int(state.epoch) if current_epoch % self.epochs_interval == 0 or current_epoch == args.num_train_epochs: logger.info(f"📊 生成第{current_epoch}轮对抗训练验证分析...") model.eval() predictions = [] true_labels = [] prediction_probs = [] device = next(model.parameters()).device with torch.no_grad(): for i in range(len(self.eval_dataset)): item = self.eval_dataset[i] input_ids = item['input_ids'].unsqueeze(0).to(device) attention_mask = item['attention_mask'].unsqueeze(0).to(device) outputs = model(input_ids=input_ids, attention_mask=attention_mask) logits = outputs['logits'] probs = F.softmax(logits, dim=-1) pred = torch.argmax(logits, dim=-1).cpu().item() predictions.append(pred) true_labels.append(item['labels'].item()) prediction_probs.append(probs.cpu().numpy()) # 计算详细指标 cm = confusion_matrix(true_labels, predictions) self.confusion_matrices[current_epoch] = cm # 计算类别特定指标 if cm.shape == (2, 2): tn, fp, fn, tp = cm.ravel() else: # 处理可能的单类别或异常情况 logger.warning(f"Epoch {current_epoch}: 混淆矩阵形状异常: {cm.shape}") if len(np.unique(true_labels)) == 1: # 只有一个类别的情况 if true_labels[0] == 0: tn, fp, fn, tp = len(true_labels), 0, 0, 0 else: tn, fp, fn, tp = 0, 0, 0, len(true_labels) else: # 其他异常情况 tn, fp, fn, tp = 0, 0, 0, 0 # 少数类(1)指标 minority_precision = tp / (tp + fp) if (tp + fp) > 0 else 0 minority_recall = tp / (tp + fn) if (tp + fn) > 0 else 0 minority_f1 = 2 * (minority_precision * minority_recall) / (minority_precision + minority_recall) if (minority_precision + minority_recall) > 0 else 0 # 多数类(0)指标 majority_precision = tn / (tn + fn) if (tn + fn) > 0 else 0 majority_recall = tn / (tn + fp) if (tn + fp) > 0 else 0 majority_f1 = 2 * (majority_precision * majority_recall) / (majority_precision + majority_recall) if (majority_precision + majority_recall) > 0 else 0 overall_accuracy = (tp + tn) / (tp + tn + fp + fn) balanced_accuracy = (minority_recall + majority_recall) / 2 metrics = { 'overall_accuracy': overall_accuracy, 'balanced_accuracy': balanced_accuracy, 'minority_precision': minority_precision, 'minority_recall': minority_recall, 'minority_f1': minority_f1, 'majority_precision': majority_precision, 'majority_recall': majority_recall, 'majority_f1': majority_f1 } self.detailed_metrics[current_epoch] = metrics logger.info(f" 📈 Epoch {current_epoch} 详细指标:") logger.info(f" 总体准确率: {overall_accuracy:.4f}") logger.info(f" 平衡准确率: {balanced_accuracy:.4f}") logger.info(f" 少数类F1: {minority_f1:.4f}") logger.info(f" 多数类F1: {majority_f1:.4f}") self.save_adversarial_analysis(cm, metrics, current_epoch) model.train() def save_adversarial_analysis(self, cm, metrics, epoch): """保存对抗训练分析图表""" fig, axes = plt.subplots(2, 2, figsize=(15, 12)) # 混淆矩阵 sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=axes[0,0], xticklabels=['Same Paragraph (0)', 'Different Paragraph (1)'], yticklabels=['Same Paragraph (0)', 'Different Paragraph (1)']) axes[0,0].set_title(f'Adversarial Validation CM - Epoch {epoch}') axes[0,0].set_xlabel('Predicted Label') axes[0,0].set_ylabel('True Label') # 类别性能对比 categories = ['Majority Class', 'Minority Class'] precision_scores = [metrics['majority_precision'], metrics['minority_precision']] recall_scores = [metrics['majority_recall'], metrics['minority_recall']] f1_scores = [metrics['majority_f1'], metrics['minority_f1']] x = np.arange(len(categories)) width = 0.25 axes[0,1].bar(x - width, precision_scores, width, label='Precision', alpha=0.8) axes[0,1].bar(x, recall_scores, width, label='Recall', alpha=0.8) axes[0,1].bar(x + width, f1_scores, width, label='F1-Score', alpha=0.8) axes[0,1].set_xlabel('Class') axes[0,1].set_ylabel('Score') axes[0,1].set_title(f'Class Performance - Epoch {epoch}') axes[0,1].set_xticks(x) axes[0,1].set_xticklabels(categories) axes[0,1].legend() axes[0,1].set_ylim(0, 1) # 准确率趋势 if len(self.detailed_metrics) > 1: epochs_list = sorted(self.detailed_metrics.keys()) overall_acc = [self.detailed_metrics[e]['overall_accuracy'] for e in epochs_list] balanced_acc = [self.detailed_metrics[e]['balanced_accuracy'] for e in epochs_list] minority_f1_trend = [self.detailed_metrics[e]['minority_f1'] for e in epochs_list] axes[1,0].plot(epochs_list, overall_acc, 'b-o', label='Overall Accuracy', linewidth=2) axes[1,0].plot(epochs_list, balanced_acc, 'g-o', label='Balanced Accuracy', linewidth=2) axes[1,0].plot(epochs_list, minority_f1_trend, 'r-o', label='Minority F1', linewidth=2) axes[1,0].set_xlabel('Epoch') axes[1,0].set_ylabel('Score') axes[1,0].set_title('Adversarial Training Progress') axes[1,0].legend() axes[1,0].grid(True, alpha=0.3) # 类别不平衡可视化 if cm.shape == (2, 2): class_counts = [cm[0,0] + cm[0,1], cm[1,0] + cm[1,1]] # 真实类别数量 predicted_counts = [cm[0,0] + cm[1,0], cm[0,1] + cm[1,1]] # 预测类别数量 x_pos = [0, 1] axes[1,1].bar([x - 0.2 for x in x_pos], class_counts, 0.4, label='True Distribution', alpha=0.7) axes[1,1].bar([x + 0.2 for x in x_pos], predicted_counts, 0.4, label='Predicted Distribution', alpha=0.7) axes[1,1].set_xlabel('Class') axes[1,1].set_ylabel('Count') axes[1,1].set_title('Class Distribution Analysis') axes[1,1].set_xticks(x_pos) axes[1,1].set_xticklabels(['Majority (0)', 'Minority (1)']) axes[1,1].legend() else: # 处理异常混淆矩阵的情况 axes[1,1].text(0.5, 0.5, f'Confusion Matrix Shape: {cm.shape}', ha='center', va='center', transform=axes[1,1].transAxes) axes[1,1].set_title('Class Distribution Analysis') # 添加性能文本 textstr = f'''Epoch {epoch} Performance: Overall Acc: {metrics['overall_accuracy']:.4f} Balanced Acc: {metrics['balanced_accuracy']:.4f} Minority F1: {metrics['minority_f1']:.4f} Minority Recall: {metrics['minority_recall']:.4f}''' props = dict(boxstyle='round', facecolor='wheat', alpha=0.8) fig.text(0.02, 0.98, textstr, transform=fig.transFigure, fontsize=10, verticalalignment='top', bbox=props) plt.tight_layout() save_path = os.path.join(self.output_dir, f'adversarial_analysis_epoch_{epoch}.png') plt.savefig(save_path, dpi=300, bbox_inches='tight') plt.close() logger.info(f" 💾 对抗训练分析图表已保存: {save_path}") class SentencePairDataset(Dataset): """句子对数据集类 - 对抗训练优化版""" def __init__(self, data, tokenizer, max_length=512, is_validation=False, adversarial_config=None): self.data = data self.tokenizer = tokenizer self.max_length = max_length self.is_validation = is_validation self.adversarial_config = adversarial_config self.valid_data = [item for item in data if item['label'] in [0, 1]] dataset_type = "验证" if is_validation else "训练" logger.info(f"原始{dataset_type}数据: {len(data)} 条,有效数据: {len(self.valid_data)} 条") self.sentence1_list = [item['sentence1'] for item in self.valid_data] self.sentence2_list = [item['sentence2'] for item in self.valid_data] self.labels = [item['label'] for item in self.valid_data] # 对抗训练专用的类别分析 if not is_validation and adversarial_config: self.class_counts = Counter(self.labels) self.class_weights = self._compute_adversarial_class_weights() self.sample_weights = self._compute_adversarial_sample_weights() logger.info(f"对抗训练类别权重: {self.class_weights}") def _compute_adversarial_class_weights(self): """计算对抗训练专用的类别权重 - 保持原始权重比例""" total_samples = len(self.labels) class_weights = {} # 使用与原始训练相同的权重计算方法 for label in [0, 1]: count = self.class_counts[label] class_weights[label] = total_samples / (2 * count) # 标准的平衡权重计算 return class_weights def _compute_adversarial_sample_weights(self): """计算对抗训练专用的样本权重""" sample_weights = [] for label in self.labels: sample_weights.append(self.class_weights[label]) return torch.tensor(sample_weights, dtype=torch.float) def get_weighted_sampler(self): """返回对抗训练优化的加权随机采样器""" if self.is_validation: raise ValueError("验证集不需要加权采样器") return WeightedRandomSampler( weights=self.sample_weights, num_samples=len(self.sample_weights), replacement=True ) def __len__(self): return len(self.valid_data) def __getitem__(self, idx): sentence1 = str(self.sentence1_list[idx]) sentence2 = str(self.sentence2_list[idx]) label = self.labels[idx] encoding = self.tokenizer( sentence1, sentence2, truncation=True, padding='max_length', max_length=self.max_length, return_tensors='pt' ) return { 'input_ids': encoding['input_ids'].flatten(), 'attention_mask': encoding['attention_mask'].flatten(), 'labels': torch.tensor(label, dtype=torch.long) } def load_and_split_data(train_file, validation_split=0.2, random_state=42): """加载数据并划分训练集和验证集 - 对抗训练版""" try: with open(train_file, 'r', encoding='utf-8') as f: all_data = json.load(f) logger.info(f"成功加载原始数据: {len(all_data)} 条记录") valid_data = [item for item in all_data if item['label'] in [0, 1]] logger.info(f"有效数据: {len(valid_data)} 条记录") labels = [item['label'] for item in valid_data] train_data, val_data = train_test_split( valid_data, test_size=validation_split, random_state=random_state, stratify=labels ) logger.info(f"对抗训练数据划分完成:") logger.info(f" 🔹 训练集: {len(train_data)} 条") logger.info(f" 🔹 验证集: {len(val_data)} 条") # 详细的类别不平衡分析 train_labels = [item['label'] for item in train_data] val_labels = [item['label'] for item in val_data] train_counts = Counter(train_labels) val_counts = Counter(val_labels) imbalance_ratio = train_counts[0] / train_counts[1] logger.info(f"📊 数据不平衡分析:") logger.info(f" 训练集 - 多数类:{train_counts[0]} 少数类:{train_counts[1]} 比例:{imbalance_ratio:.1f}:1") logger.info(f" 验证集 - 多数类:{val_counts[0]} 少数类:{val_counts[1]} 比例:{val_counts[0]/val_counts[1]:.1f}:1") return train_data, val_data, imbalance_ratio except Exception as e: logger.error(f"加载和划分数据失败: {str(e)}") return None, None, None def compute_adversarial_metrics(eval_pred): """计算对抗训练专用评估指标""" predictions, labels = eval_pred if isinstance(predictions, (list, tuple)): predictions = predictions[0] if not isinstance(predictions, np.ndarray): predictions = np.array(predictions) if len(predictions.shape) == 3: predictions = predictions[:, -1, :] elif len(predictions.shape) == 1: predictions = predictions.reshape(-1, 2) if len(predictions.shape) != 2: predictions = predictions.reshape(-1, 2) try: predictions = np.argmax(predictions, axis=1) except Exception as e: logger.error(f"预测处理错误: {e}") predictions = np.array([np.argmax(pred) if len(pred) > 1 else 0 for pred in predictions]) # 确保predictions和labels长度一致 min_length = min(len(predictions), len(labels)) predictions = predictions[:min_length] labels = labels[:min_length] logger.info(f"评估样本数量: predictions={len(predictions)}, labels={len(labels)}") # 计算详细指标 accuracy = accuracy_score(labels, predictions) # 计算混淆矩阵,确保样本数量一致 try: cm = confusion_matrix(labels, predictions) if cm.shape == (2, 2): tn, fp, fn, tp = cm.ravel() else: # 处理可能的单类别情况 logger.warning(f"混淆矩阵形状异常: {cm.shape}") if len(np.unique(labels)) == 1: # 只有一个类别的情况 if labels[0] == 0: tn, fp, fn, tp = len(labels), 0, 0, 0 else: tn, fp, fn, tp = 0, 0, 0, len(labels) else: # 其他异常情况,使用默认值 tn, fp, fn, tp = 0, 0, 0, 0 except Exception as e: logger.error(f"混淆矩阵计算错误: {e}") tn, fp, fn, tp = 0, 0, 0, 0 # 少数类指标 minority_precision = tp / (tp + fp) if (tp + fp) > 0 else 0 minority_recall = tp / (tp + fn) if (tp + fn) > 0 else 0 minority_f1 = 2 * (minority_precision * minority_recall) / (minority_precision + minority_recall) if (minority_precision + minority_recall) > 0 else 0 # 平衡准确率 majority_recall = tn / (tn + fp) if (tn + fp) > 0 else 0 balanced_accuracy = (minority_recall + majority_recall) / 2 return { 'accuracy': accuracy, 'balanced_accuracy': balanced_accuracy, 'minority_precision': minority_precision, 'minority_recall': minority_recall, 'minority_f1': minority_f1, } class FocalLoss(nn.Module): """保持原始设置的Focal Loss""" def __init__(self, alpha=None, gamma=3.0, reduction='mean'): # 使用原始gamma=3.0 super(FocalLoss, self).__init__() self.alpha = alpha self.gamma = gamma self.reduction = reduction def forward(self, inputs, targets): ce_loss = F.cross_entropy(inputs, targets, reduction='none') pt = torch.exp(-ce_loss) if self.alpha is not None: if self.alpha.type() != inputs.data.type(): self.alpha = self.alpha.type_as(inputs.data) at = self.alpha.gather(0, targets.data.view(-1)) ce_loss = ce_loss * at focal_weight = (1 - pt) ** self.gamma focal_loss = focal_weight * ce_loss if self.reduction == 'mean': return focal_loss.mean() elif self.reduction == 'sum': return focal_loss.sum() else: return focal_loss class ScaledDotProductAttention(nn.Module): """缩放点积注意力机制 - 对抗训练优化""" def __init__(self, d_model, dropout=0.1): super(ScaledDotProductAttention, self).__init__() self.d_model = d_model self.dropout = nn.Dropout(dropout) def forward(self, query, key, value, mask=None): batch_size, seq_len, d_model = query.size() scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_model) if mask is not None: mask_value = torch.finfo(scores.dtype).min scores = scores.masked_fill(mask == 0, mask_value) attention_weights = F.softmax(scores, dim=-1) attention_weights = self.dropout(attention_weights) output = torch.matmul(attention_weights, value) return output, attention_weights class RoBERTaWithAdversarialTraining(nn.Module): """带对抗训练的RoBERTa模型 - 保持原始参数设置""" def __init__(self, model_path, num_labels=2, dropout=0.1, focal_alpha=None, focal_gamma=3.0, adversarial_config=None): # 保持原始gamma=3.0 super(RoBERTaWithAdversarialTraining, self).__init__() self.roberta = BertModel.from_pretrained(model_path) self.config = self.roberta.config self.config.num_labels = num_labels self.adversarial_config = adversarial_config self.scaled_attention = ScaledDotProductAttention( d_model=self.config.hidden_size, dropout=dropout ) self.dropout = nn.Dropout(dropout) self.classifier = nn.Linear(self.config.hidden_size, num_labels) self.focal_loss = FocalLoss(alpha=focal_alpha, gamma=focal_gamma) # 对抗训练组件 self.perturbation_generator = AdversarialPerturbationGenerator(adversarial_config) self._init_weights() self.focal_alpha = focal_alpha self.focal_gamma = focal_gamma self.current_phase = 'adaptation' def _init_weights(self): """初始化新增层的权重""" nn.init.normal_(self.classifier.weight, std=0.02) nn.init.zeros_(self.classifier.bias) def set_training_phase(self, phase): """设置训练阶段""" if self.current_phase != phase: self.current_phase = phase # 只在阶段真正改变时记录日志 def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, labels=None, inputs_embeds=None, adversarial_training=False): if adversarial_training and labels is not None and self.training: return self._adversarial_forward(input_ids, attention_mask, labels) else: return self._standard_forward(input_ids, attention_mask, token_type_ids, labels, inputs_embeds) def _standard_forward(self, input_ids, attention_mask, token_type_ids, labels, inputs_embeds): """标准前向传播""" roberta_outputs = self.roberta( input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds, return_dict=True ) sequence_output = roberta_outputs.last_hidden_state enhanced_output, attention_weights = self.scaled_attention( query=sequence_output, key=sequence_output, value=sequence_output, mask=attention_mask.unsqueeze(1) if attention_mask is not None else None ) cls_output = enhanced_output[:, 0, :] cls_output = self.dropout(cls_output) logits = self.classifier(cls_output) loss = None if labels is not None: loss = self.focal_loss(logits, labels) return { 'loss': loss, 'logits': logits, 'hidden_states': enhanced_output, 'attention_weights': attention_weights } def _adversarial_forward(self, input_ids, attention_mask, labels): """对抗训练前向传播""" inputs = {'input_ids': input_ids, 'attention_mask': attention_mask} # 生成对抗扰动 perturbed_embeds = self.perturbation_generator.generate_perturbation( self, inputs, labels, self.current_phase ) # 使用扰动后的嵌入进行前向传播 return self._standard_forward( input_ids=None, attention_mask=attention_mask, token_type_ids=None, labels=labels, inputs_embeds=perturbed_embeds ) def save_pretrained(self, save_directory): """保存模型""" os.makedirs(save_directory, exist_ok=True) model_to_save = self.module if hasattr(self, 'module') else self torch.save(model_to_save.state_dict(), os.path.join(save_directory, 'pytorch_model.bin')) config_dict = { 'model_type': 'RoBERTaWithAdversarialTraining', 'base_model': 'chinese-roberta-wwm-ext', 'num_labels': self.config.num_labels, 'hidden_size': self.config.hidden_size, 'focal_alpha': self.focal_alpha.tolist() if self.focal_alpha is not None else None, 'focal_gamma': self.focal_gamma, 'has_scaled_attention': True, 'has_focal_loss': True, 'has_adversarial_training': True, 'adversarial_config': { 'imbalance_ratio': self.adversarial_config.imbalance_ratio, 'eps_ratio': self.adversarial_config.eps_ratio, 'adv_steps': self.adversarial_config.adv_steps, 'norm_type': self.adversarial_config.norm_type }, 'optimization_level': 'adversarial_training_with_best_validation' } with open(os.path.join(save_directory, 'config.json'), 'w', encoding='utf-8') as f: json.dump(config_dict, f, ensure_ascii=False, indent=2) class AdversarialTrainer(Trainer): """自定义对抗训练器""" def __init__(self, adversarial_config, weighted_sampler=None, *args, **kwargs): super().__init__(*args, **kwargs) self.adversarial_config = adversarial_config self.weighted_sampler = weighted_sampler self.current_epoch = 0 self.current_phase = 'adaptation' # 初始化当前阶段 self.phase_transitions = {8: 'enhancement', 25: 'refinement'} self._phase_logged = False # 避免重复日志 def get_train_dataloader(self): if self.train_dataset is None: raise ValueError("Trainer: training requires a train_dataset.") train_dataset = self.train_dataset if self.weighted_sampler is not None: train_sampler = self.weighted_sampler else: train_sampler = self._get_train_sampler() return DataLoader( train_dataset, batch_size=self.args.train_batch_size, sampler=train_sampler, collate_fn=self.data_collator, drop_last=self.args.dataloader_drop_last, num_workers=self.args.dataloader_num_workers, pin_memory=self.args.dataloader_pin_memory, ) def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): """自定义损失计算 - 混合对抗训练""" # 确定当前训练阶段(避免频繁调用) current_phase = self._get_current_phase() if current_phase != self.current_phase: self.current_phase = current_phase model.set_training_phase(current_phase) logger.info(f"🎯 切换到训练阶段: {current_phase}") elif not self._phase_logged: model.set_training_phase(current_phase) self._phase_logged = True # 获取阶段配置 phase_config = self.adversarial_config.phase_configs[current_phase] clean_ratio = phase_config['clean_ratio'] labels = inputs.get("labels") # 混合训练策略 if self.state.global_step % 1 == 0: # 每步都进行混合训练 batch_size = labels.size(0) clean_size = int(batch_size * clean_ratio) if clean_size > 0: # 干净样本训练 clean_inputs = {k: v[:clean_size] for k, v in inputs.items()} clean_outputs = model(**clean_inputs, adversarial_training=False) clean_loss = clean_outputs['loss'] else: clean_loss = 0 if clean_size < batch_size: # 对抗样本训练 adv_inputs = {k: v[clean_size:] for k, v in inputs.items()} adv_outputs = model(**adv_inputs, adversarial_training=True) adv_loss = adv_outputs['loss'] else: adv_loss = 0 # 组合损失 if isinstance(clean_loss, torch.Tensor) and isinstance(adv_loss, torch.Tensor): total_loss = clean_ratio * clean_loss + (1 - clean_ratio) * adv_loss outputs = clean_outputs if clean_size > 0 else adv_outputs elif isinstance(clean_loss, torch.Tensor): total_loss = clean_loss outputs = clean_outputs else: total_loss = adv_loss outputs = adv_outputs outputs['loss'] = total_loss else: # 标准训练 outputs = model(**inputs, adversarial_training=False) total_loss = outputs['loss'] return (total_loss, outputs) if return_outputs else total_loss def _get_current_phase(self): """根据当前epoch确定训练阶段 - 温和对抗训练版本""" current_epoch = int(self.state.epoch) if self.state.epoch else 0 if current_epoch < 10: # 延长适应期 return 'adaptation' elif current_epoch < 25: return 'enhancement' else: return 'refinement' def on_epoch_begin(self, args, state, control, **kwargs): """Epoch开始时的处理""" super().on_epoch_begin(args, state, control, **kwargs) current_epoch = int(state.epoch) # 调整阶段切换点以适应温和训练 phase_transitions = {10: 'enhancement', 25: 'refinement'} # 延长适应期到10轮 # 检查是否需要切换阶段 if current_epoch in phase_transitions: new_phase = phase_transitions[current_epoch] self.current_phase = new_phase self._phase_logged = False # 重置日志标志 logger.info(f"🔄 Epoch {current_epoch}: 切换到训练阶段 '{new_phase}' (温和对抗训练)") # 更新学习率 phase_config = self.adversarial_config.phase_configs[new_phase] new_lr = phase_config['learning_rate'] for param_group in self.optimizer.param_groups: param_group['lr'] = new_lr logger.info(f" 📉 学习率调整为: {new_lr}") def plot_adversarial_training_curves(loss_tracker, output_dir): """绘制对抗训练专用的可视化图表""" plt.figure(figsize=(20, 15)) # 1. 训练损失趋势 if loss_tracker.train_losses: plt.subplot(3, 3, 1) plt.plot(loss_tracker.train_steps, loss_tracker.train_losses, 'b-', label='Training Loss', linewidth=2, alpha=0.8) plt.title('Adversarial Training Loss', fontsize=14, fontweight='bold') plt.xlabel('Training Steps') plt.ylabel('Loss Value') plt.legend() plt.grid(True, alpha=0.3) # 2. 验证损失vs验证准确率 if loss_tracker.eval_losses and loss_tracker.eval_accuracies: plt.subplot(3, 3, 2) ax1 = plt.gca() ax2 = ax1.twinx() line1 = ax1.plot(loss_tracker.eval_steps, loss_tracker.eval_losses, 'r-', label='Validation Loss', linewidth=2) line2 = ax2.plot(loss_tracker.eval_steps, loss_tracker.eval_accuracies, 'g-', label='Validation Accuracy', linewidth=2) ax1.set_xlabel('Training Steps') ax1.set_ylabel('Loss', color='r') ax2.set_ylabel('Accuracy', color='g') ax1.set_title('Validation Metrics') lines = line1 + line2 labels = [l.get_label() for l in lines] ax1.legend(lines, labels, loc='center right') ax1.grid(True, alpha=0.3) # 3. 少数类F1趋势 if loss_tracker.eval_minority_f1: plt.subplot(3, 3, 3) plt.plot(loss_tracker.eval_steps[:len(loss_tracker.eval_minority_f1)], loss_tracker.eval_minority_f1, 'purple', label='Minority F1', linewidth=2, marker='o', markersize=4) plt.title('Minority Class F1 Score Progress', fontsize=14, fontweight='bold') plt.xlabel('Training Steps') plt.ylabel('F1 Score') plt.legend() plt.grid(True, alpha=0.3) # 标记最佳F1 if loss_tracker.best_minority_f1 > 0: plt.axhline(y=loss_tracker.best_minority_f1, color='red', linestyle='--', alpha=0.7) plt.text(0.02, 0.98, f'Best: {loss_tracker.best_minority_f1:.4f}', transform=plt.gca().transAxes, fontsize=10, verticalalignment='top', bbox=dict(boxstyle="round,pad=0.3", facecolor="yellow", alpha=0.7)) # 4. 训练vs验证准确率对比 if loss_tracker.eval_accuracies: plt.subplot(3, 3, 4) plt.plot(loss_tracker.eval_steps, loss_tracker.eval_accuracies, 'g-o', label='Validation Accuracy', linewidth=2, markersize=3) plt.title('Validation Accuracy Trend', fontsize=14, fontweight='bold') plt.xlabel('Training Steps') plt.ylabel('Accuracy') plt.legend() plt.grid(True, alpha=0.3) # 5. 损失收敛分析 if len(loss_tracker.train_losses) > 50: plt.subplot(3, 3, 5) # 计算移动平均 window_size = 20 train_ma = [] for i in range(window_size, len(loss_tracker.train_losses)): train_ma.append(np.mean(loss_tracker.train_losses[i-window_size:i])) plt.plot(loss_tracker.train_steps[window_size:], train_ma, 'b-', label=f'Training Loss MA({window_size})', linewidth=2) plt.title('Loss Convergence Analysis', fontsize=14, fontweight='bold') plt.xlabel('Training Steps') plt.ylabel('Moving Average Loss') plt.legend() plt.grid(True, alpha=0.3) # 6. 性能提升统计 plt.subplot(3, 3, 6) if loss_tracker.eval_accuracies and loss_tracker.eval_minority_f1: metrics_data = { 'Overall Accuracy': loss_tracker.eval_accuracies[-1] if loss_tracker.eval_accuracies else 0, 'Best Overall Acc': max(loss_tracker.eval_accuracies) if loss_tracker.eval_accuracies else 0, 'Minority F1': loss_tracker.eval_minority_f1[-1] if loss_tracker.eval_minority_f1 else 0, 'Best Minority F1': loss_tracker.best_minority_f1 } metrics_names = list(metrics_data.keys()) metrics_values = list(metrics_data.values()) bars = plt.bar(metrics_names, metrics_values, color=['skyblue', 'lightblue', 'lightcoral', 'coral']) plt.title('Performance Summary', fontsize=14, fontweight='bold') plt.ylabel('Score') plt.xticks(rotation=45) # 添加数值标签 for bar, value in zip(bars, metrics_values): plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01, f'{value:.3f}', ha='center', va='bottom') # 7. 训练阶段可视化 plt.subplot(3, 3, 7) if loss_tracker.epochs: phase_colors = [] for epoch in loss_tracker.epochs: if epoch < 8: phase_colors.append('lightgreen') # adaptation elif epoch < 25: phase_colors.append('orange') # enhancement else: phase_colors.append('red') # refinement plt.scatter(loss_tracker.epochs, [1]*len(loss_tracker.epochs), c=phase_colors, s=50, alpha=0.7) plt.title('Training Phases Timeline', fontsize=14, fontweight='bold') plt.xlabel('Epoch') plt.ylabel('Phase') plt.yticks([1], ['Phases']) # 添加图例 from matplotlib.patches import Patch legend_elements = [Patch(facecolor='lightgreen', label='Adaptation (1-8)'), Patch(facecolor='orange', label='Enhancement (9-25)'), Patch(facecolor='red', label='Refinement (26+)')] plt.legend(handles=legend_elements, loc='upper right') # 8. 对抗训练效果对比(如果有基准数据) plt.subplot(3, 3, 8) baseline_acc = 0.9451 # 从你的混淆矩阵获得 current_acc = loss_tracker.eval_accuracies[-1] if loss_tracker.eval_accuracies else 0 improvement = ((current_acc - baseline_acc) / baseline_acc * 100) if baseline_acc > 0 else 0 categories = ['Baseline\n(No Adversarial)', 'Current\n(Adversarial)'] accuracies = [baseline_acc, current_acc] colors = ['lightgray', 'lightgreen' if improvement > 0 else 'lightcoral'] bars = plt.bar(categories, accuracies, color=colors) plt.title('Adversarial Training Impact', fontsize=14, fontweight='bold') plt.ylabel('Accuracy') plt.ylim(0.9, 1.0) # 添加改进百分比 plt.text(0.5, 0.95, f'Improvement: {improvement:+.2f}%', transform=plt.gca().transAxes, ha='center', bbox=dict(boxstyle="round,pad=0.3", facecolor="lightgreen" if improvement > 0 else "lightcoral", alpha=0.7)) # 9. 训练时间分析 plt.subplot(3, 3, 9) if loss_tracker.train_steps: step_intervals = [] for i in range(1, len(loss_tracker.train_steps)): step_intervals.append(loss_tracker.train_steps[i] - loss_tracker.train_steps[i-1]) if step_intervals: plt.hist(step_intervals, bins=20, alpha=0.7, color='skyblue', edgecolor='black') plt.title('Training Step Intervals', fontsize=14, fontweight='bold') plt.xlabel('Steps Between Logs') plt.ylabel('Frequency') plt.grid(True, alpha=0.3) plt.tight_layout() # 保存综合对抗训练图表 curves_path = os.path.join(output_dir, 'comprehensive_adversarial_training_curves.png') plt.savefig(curves_path, dpi=300, bbox_inches='tight') plt.close() logger.info(f"📈 对抗训练综合曲线已保存: {curves_path}") # 单独保存重要指标图表 plt.figure(figsize=(15, 10)) # 少数类F1和整体准确率对比 plt.subplot(2, 2, 1) if loss_tracker.eval_accuracies and loss_tracker.eval_minority_f1: min_len = min(len(loss_tracker.eval_accuracies), len(loss_tracker.eval_minority_f1)) steps = loss_tracker.eval_steps[:min_len] acc = loss_tracker.eval_accuracies[:min_len] f1 = loss_tracker.eval_minority_f1[:min_len] plt.plot(steps, acc, 'b-o', label='Overall Accuracy', linewidth=2, markersize=4) plt.plot(steps, f1, 'r-o', label='Minority F1', linewidth=2, markersize=4) plt.title('Key Metrics Comparison', fontsize=16, fontweight='bold') plt.xlabel('Training Steps') plt.ylabel('Score') plt.legend() plt.grid(True, alpha=0.3) # 训练损失平滑曲线 plt.subplot(2, 2, 2) if len(loss_tracker.train_losses) > 10: # 应用高斯平滑 from scipy import ndimage smoothed_loss = ndimage.gaussian_filter1d(loss_tracker.train_losses, sigma=2) plt.plot(loss_tracker.train_steps, smoothed_loss, 'g-', linewidth=3, alpha=0.8) plt.title('Smoothed Training Loss', fontsize=16, fontweight='bold') plt.xlabel('Training Steps') plt.ylabel('Loss') plt.grid(True, alpha=0.3) # 性能改进总结 plt.subplot(2, 2, 3) if loss_tracker.eval_accuracies and loss_tracker.eval_minority_f1: initial_acc = loss_tracker.eval_accuracies[0] final_acc = loss_tracker.eval_accuracies[-1] initial_f1 = loss_tracker.eval_minority_f1[0] if loss_tracker.eval_minority_f1 else 0 final_f1 = loss_tracker.eval_minority_f1[-1] if loss_tracker.eval_minority_f1 else 0 improvements = { 'Overall Accuracy': ((final_acc - initial_acc) / initial_acc * 100) if initial_acc > 0 else 0, 'Minority F1': ((final_f1 - initial_f1) / initial_f1 * 100) if initial_f1 > 0 else 0 } colors = ['green' if imp > 0 else 'red' for imp in improvements.values()] bars = plt.bar(improvements.keys(), improvements.values(), color=colors, alpha=0.7) plt.title('Performance Improvements (%)', fontsize=16, fontweight='bold') plt.ylabel('Improvement (%)') plt.axhline(y=0, color='black', linestyle='-', alpha=0.3) for bar, imp in zip(bars, improvements.values()): plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.5, f'{imp:+.1f}%', ha='center', va='bottom') # 训练稳定性分析 plt.subplot(2, 2, 4) if len(loss_tracker.eval_accuracies) > 5: # 计算准确率的标准差(稳定性指标) acc_std = np.std(loss_tracker.eval_accuracies[-10:]) # 最后10个评估的标准差 plt.text(0.5, 0.7, f'Recent Stability\n(Last 10 evals)', ha='center', va='center', transform=plt.gca().transAxes, fontsize=14, fontweight='bold') plt.text(0.5, 0.5, f'Std Dev: {acc_std:.4f}', ha='center', va='center', transform=plt.gca().transAxes, fontsize=12) stability_rating = "High" if acc_std < 0.01 else "Medium" if acc_std < 0.02 else "Low" color = "green" if stability_rating == "High" else "orange" if stability_rating == "Medium" else "red" plt.text(0.5, 0.3, f'Rating: {stability_rating}', ha='center', va='center', transform=plt.gca().transAxes, fontsize=12, color=color, fontweight='bold') plt.axis('off') plt.tight_layout() key_metrics_path = os.path.join(output_dir, 'key_adversarial_metrics.png') plt.savefig(key_metrics_path, dpi=300, bbox_inches='tight') plt.close() logger.info(f"📈 关键对抗训练指标图表已保存: {key_metrics_path}") def train_adversarial_roberta_model(train_data, val_data, imbalance_ratio, base_model_path, model_path="/root/autodl-tmp/chinese-roberta-wwm-ext/chinese-roberta-wwm-ext/model", output_dir="/root/autodl-tmp/chinese-roberta-wwm-ext/chinese-roberta-wwm-ext/model_adversarial", checkpoint_dir="/root/autodl-tmp/chinese-roberta-wwm-ext/chinese-roberta-wwm-ext/output_adversarial"): """完整的对抗训练流程 - V100 48GB优化版""" gpu_available, gpu_memory, gpu_type = check_gpu_availability() device = torch.device('cuda') logger.info(f"🚀 使用GPU设备: {device} ({gpu_type})") # 初始化对抗训练配置 adv_config = AdversarialConfig(imbalance_ratio=imbalance_ratio, gpu_type=gpu_type) logger.info(f"📥 基于预训练模型进行对抗训练: {base_model_path}") tokenizer = BertTokenizer.from_pretrained(model_path) # 创建对抗训练优化的Focal Loss alpha_tensor = torch.tensor(adv_config.focal_alpha, dtype=torch.float).to(device) model = RoBERTaWithAdversarialTraining( model_path=model_path, num_labels=2, dropout=0.1, focal_alpha=alpha_tensor, focal_gamma=adv_config.focal_gamma, adversarial_config=adv_config ) # 如果有预训练的模型权重,加载它们 if os.path.exists(base_model_path): logger.info(f"🔄 加载预训练模型权重: {base_model_path}") try: checkpoint = torch.load(os.path.join(base_model_path, 'pytorch_model.bin'), map_location='cpu') model.load_state_dict(checkpoint, strict=False) logger.info("✅ 预训练模型权重加载成功") except Exception as e: logger.warning(f"⚠️ 无法加载预训练权重,使用默认初始化: {str(e)}") model = model.to(device) logger.info(f"✅ 对抗训练模型已加载到GPU: {next(model.parameters()).device}") total_params = sum(p.numel() for p in model.parameters()) trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) logger.info(f"📊 对抗训练模型参数统计:") logger.info(f" 🔹 总参数量: {total_params:,}") logger.info(f" 🔹 可训练参数: {trainable_params:,}") logger.info("📋 准备对抗训练数据集...") train_dataset = SentencePairDataset(train_data, tokenizer, max_length=512, is_validation=False, adversarial_config=adv_config) val_dataset = SentencePairDataset(val_data, tokenizer, max_length=512, is_validation=True) weighted_sampler = train_dataset.get_weighted_sampler() logger.info(f" 🔹 训练集大小: {len(train_dataset)}") logger.info(f" 🔹 验证集大小: {len(val_dataset)}") # V100 48GB对抗训练优化配置 if gpu_type == "v100_48gb": batch_size = 12 # 对抗训练需要更多内存,适当降低 gradient_accumulation = 3 # 保持有效批次大小36 max_grad_norm = 1.0 fp16 = True dataloader_num_workers = 4 else: batch_size = 8 gradient_accumulation = 4 max_grad_norm = 0.5 fp16 = True dataloader_num_workers = 2 effective_batch_size = batch_size * gradient_accumulation # 确保输出目录存在 os.makedirs(output_dir, exist_ok=True) os.makedirs(checkpoint_dir, exist_ok=True) training_args = TrainingArguments( output_dir=checkpoint_dir, num_train_epochs=30, # 减少总轮次,温和对抗训练 per_device_train_batch_size=batch_size, per_device_eval_batch_size=batch_size, gradient_accumulation_steps=gradient_accumulation, eval_strategy="epoch", eval_steps=1, save_strategy="epoch", save_steps=10, # 改为每10个epoch保存一次 logging_strategy="steps", logging_steps=100, # 进一步减少日志频率 warmup_ratio=0.15, # 使用原始warmup比例 weight_decay=0.01, learning_rate=2e-5, # 使用原始学习率 load_best_model_at_end=True, remove_unused_columns=False, dataloader_pin_memory=True, fp16=fp16, dataloader_num_workers=dataloader_num_workers, group_by_length=True, report_to=[], adam_epsilon=1e-8, max_grad_norm=max_grad_norm, save_total_limit=1, # 只保留1个最新checkpoint skip_memory_metrics=True, disable_tqdm=False, lr_scheduler_type="cosine", warmup_steps=0, metric_for_best_model="eval_accuracy", # 使用原始的准确率指标 greater_is_better=True, ) logger.info(f"🎯 对抗训练参数配置:") logger.info(f" 🔹 训练轮数: {training_args.num_train_epochs} (分3个阶段)") logger.info(f" 🔹 批次大小: {batch_size}") logger.info(f" 🔹 梯度累积: {gradient_accumulation}") logger.info(f" 🔹 有效批次大小: {effective_batch_size}") logger.info(f" 🔹 起始学习率: {training_args.learning_rate}") logger.info(f" 🔹 主要评估指标: 少数类F1分数") logger.info(f" 🔹 GPU优化配置: {gpu_type}") logger.info(f"🎯 对抗训练阶段设置:") for phase, config in adv_config.phase_configs.items(): epochs = "1-8" if phase == "adaptation" else "9-25" if phase == "enhancement" else "26+" logger.info(f" 🔹 {phase.capitalize()} ({epochs}): eps={config['eps']}, lr={config['learning_rate']}, clean_ratio={config['clean_ratio']}") data_collator = DataCollatorWithPadding(tokenizer=tokenizer) # 创建回调函数 loss_tracker = LossTracker() adv_validation_callback = AdversarialValidationCallback( eval_dataset=val_dataset, tokenizer=tokenizer, output_dir=checkpoint_dir, adv_config=adv_config, epochs_interval=5 ) trainer = AdversarialTrainer( model=model, args=training_args, train_dataset=train_dataset, eval_dataset=val_dataset, tokenizer=tokenizer, data_collator=data_collator, compute_metrics=compute_adversarial_metrics, callbacks=[loss_tracker, adv_validation_callback], adversarial_config=adv_config, weighted_sampler=weighted_sampler ) logger.info("🏃‍♂️ 开始完整的FreeLB对抗训练...") logger.info("🎯 对抗训练特性:") logger.info(" ✅ FreeLB多步对抗扰动生成") logger.info(" ✅ 类别差异化扰动强度 (少数类2.5倍)") logger.info(" ✅ 混合训练策略 (干净+对抗样本)") logger.info(" ✅ 三阶段渐进式训练") logger.info(" ✅ 强化版Focal Loss (gamma=4.0)") logger.info(" ✅ 动态学习率调整") logger.info(" ✅ 少数类F1优化目标") logger.info(" ✅ V100 48GB内存优化") logger.info(" ✅ 实时对抗效果监控") start_time = datetime.now() try: trainer.train() logger.info(f"🏆 对抗训练完成!") logger.info(f" 🔹 最佳少数类F1: {loss_tracker.best_minority_f1:.4f}") logger.info(f" 🔹 最佳整体准确率: {loss_tracker.best_eval_accuracy:.4f}") logger.info(f" 🔹 最佳模型来自: Epoch {loss_tracker.best_epoch}") except RuntimeError as e: if "out of memory" in str(e).lower(): logger.error("❌ GPU内存不足!") logger.error("💡 建议减小批次大小或对抗步数") raise else: raise end_time = datetime.now() training_duration = end_time - start_time logger.info(f"🎉 对抗训练完成! 耗时: {training_duration}") logger.info("📈 生成对抗训练可视化图表...") plot_adversarial_training_curves(loss_tracker, checkpoint_dir) logger.info(f"💾 保存最佳对抗训练模型到: {output_dir}") # 保存最佳模型 trainer.model.save_pretrained(output_dir) tokenizer.save_pretrained(output_dir) # 保存对抗训练历史记录 adversarial_history = { 'train_losses': loss_tracker.train_losses, 'train_steps': loss_tracker.train_steps, 'eval_losses': loss_tracker.eval_losses, 'eval_steps': loss_tracker.eval_steps, 'eval_accuracies': loss_tracker.eval_accuracies, 'eval_minority_f1': loss_tracker.eval_minority_f1, 'epochs': loss_tracker.epochs, 'best_eval_accuracy': loss_tracker.best_eval_accuracy, 'best_minority_f1': loss_tracker.best_minority_f1, 'best_epoch': loss_tracker.best_epoch, 'adversarial_config': { 'imbalance_ratio': adv_config.imbalance_ratio, 'focal_alpha': adv_config.focal_alpha, 'focal_gamma': adv_config.focal_gamma, 'phase_configs': adv_config.phase_configs, 'sampling_configs': adv_config.sampling_configs } } with open(os.path.join(checkpoint_dir, 'adversarial_training_history.json'), 'w', encoding='utf-8') as f: json.dump(adversarial_history, f, ensure_ascii=False, indent=2) # 保存详细的对抗训练信息 adversarial_info = { "model_name": model_path, "model_type": "Chinese-RoBERTa-WWM-Ext with FreeLB Adversarial Training", "base_model_path": base_model_path, "optimization_level": "freelb_adversarial_training", "training_duration": str(training_duration), "num_train_samples": len(train_dataset), "num_val_samples": len(val_dataset), "data_imbalance_ratio": imbalance_ratio, "adversarial_training_config": { "method": "FreeLB (Free Large-Batch)", "adversarial_steps": adv_config.adv_steps, "perturbation_type": "embedding_level", "norm_type": adv_config.norm_type, "max_norm": adv_config.max_norm, "adversarial_lr": adv_config.adv_lr, "class_differential_perturbation": True, "majority_class_eps_range": "0.08-0.12", "minority_class_eps_range": "0.2-0.3", "eps_ratio": adv_config.eps_ratio }, "training_phases": { "adaptation_phase": { "epochs": "1-8", "description": "温和对抗训练适应期", "clean_ratio": 0.8, "learning_rate": "8e-6", "perturbation_strength": "低" }, "enhancement_phase": { "epochs": "9-25", "description": "强化对抗训练期", "clean_ratio": 0.5, "learning_rate": "5e-6", "perturbation_strength": "中高" }, "refinement_phase": { "epochs": "26+", "description": "精调优化期", "clean_ratio": 0.2, "learning_rate": "2e-6", "perturbation_strength": "中" } }, "best_model_info": { "best_minority_f1": float(loss_tracker.best_minority_f1), "best_overall_accuracy": float(loss_tracker.best_eval_accuracy), "best_epoch": int(loss_tracker.best_epoch), "model_selection_criterion": "minority_f1_score", "load_best_model_at_end": True }, "focal_loss_enhancement": { "alpha_weights": adv_config.focal_alpha, "gamma": adv_config.focal_gamma, "optimization": "adversarial_training_specific", "minority_class_focus": "aggressive_9x_weight" }, "performance_targets": { "primary_target": "minority_class_f1_improvement", "secondary_target": "overall_accuracy_maintenance", "expected_minority_f1_gain": "15-30%", "expected_robustness_improvement": "significant" }, "gpu_optimization": { "gpu_name": torch.cuda.get_device_name(0), "gpu_memory_gb": gpu_memory, "optimization_target": gpu_type, "effective_batch_size": effective_batch_size, "memory_optimization": "adversarial_training_specific", "batch_size_reduction": "12 (from 16 for memory)" }, "training_args": { "num_train_epochs": training_args.num_train_epochs, "per_device_train_batch_size": training_args.per_device_train_batch_size, "gradient_accumulation_steps": training_args.gradient_accumulation_steps, "learning_rate": training_args.learning_rate, "warmup_ratio": training_args.warmup_ratio, "weight_decay": training_args.weight_decay, "fp16": training_args.fp16, "metric_for_best_model": training_args.metric_for_best_model, "load_best_model_at_end": training_args.load_best_model_at_end }, "adversarial_innovations": [ "FreeLB多步梯度扰动生成", "类别差异化扰动强度 (2.5倍比例)", "三阶段渐进式对抗训练", "混合训练策略 (干净+对抗样本)", "少数类F1优化目标", "动态学习率阶段调整", "强化版Focal Loss (gamma=4.0)", "V100大显存充分利用", "实时对抗效果监控", "类别平衡采样策略" ], "expected_improvements": [ "少数类召回率提升 20-35%", "少数类F1分数提升 15-30%", "模型鲁棒性显著增强", "对输入扰动的稳定性提升", "泛化能力改善", "整体准确率维持或轻微提升" ], "paths": { "base_model_path": base_model_path, "model_output_path": output_dir, "checkpoint_output_path": checkpoint_dir, "training_logs": checkpoint_dir }, "visualization_files": { "comprehensive_curves": "comprehensive_adversarial_training_curves.png", "key_metrics": "key_adversarial_metrics.png", "adversarial_analysis": [f"adversarial_analysis_epoch_{i}.png" for i in range(5, 36, 5)], "training_history": "adversarial_training_history.json" }, "training_completed": end_time.isoformat() } with open(os.path.join(checkpoint_dir, 'adversarial_training_info.json'), 'w', encoding='utf-8') as f: json.dump(adversarial_info, f, ensure_ascii=False, indent=2) # 在模型目录保存对抗训练摘要 adversarial_summary = { "adversarial_model_info": { "best_minority_f1": float(loss_tracker.best_minority_f1), "best_overall_accuracy": float(loss_tracker.best_eval_accuracy), "best_epoch": int(loss_tracker.best_epoch), "training_method": "FreeLB_Adversarial_Training", "selection_criterion": "minority_f1_score" }, "adversarial_config": adversarial_info } with open(os.path.join(output_dir, 'adversarial_model_info.json'), 'w', encoding='utf-8') as f: json.dump(adversarial_summary, f, ensure_ascii=False, indent=2) logger.info("📋 对抗训练信息和模型记录已保存") return trainer, trainer.model, tokenizer, loss_tracker, adv_validation_callback def main(): """主函数 - 完整的对抗训练流程""" logger.info("=" * 120) logger.info("🚀 Chinese-RoBERTa-WWM-Ext FreeLB对抗训练 (存储优化版)") logger.info("=" * 120) # 配置路径 train_file = "/root/autodl-tmp/chinese-roberta-wwm-ext/chinese-roberta-wwm-ext/Data/train_dataset.json" model_path = "/root/autodl-tmp/chinese-roberta-wwm-ext/chinese-roberta-wwm-ext/model" base_model_path = "/root/autodl-tmp/chinese-roberta-wwm-ext/chinese-roberta-wwm-ext/model_train" # 预训练模型路径 output_dir = "/root/autodl-tmp/chinese-roberta-wwm-ext/chinese-roberta-wwm-ext/model_adversarial" checkpoint_dir = "/root/autodl-tmp/chinese-roberta-wwm-ext/chinese-roberta-wwm-ext/output_adversarial" # 检查磁盘空间 import shutil total, used, free = shutil.disk_usage("/root/autodl-tmp") free_gb = free // (1024**3) logger.info(f"📁 磁盘空间检查:") logger.info(f" 🔹 可用空间: {free_gb} GB") if free_gb < 5: logger.warning("⚠️ 磁盘空间不足,建议清理后再运行训练") logger.info("💡 建议执行以下命令清理空间:") logger.info(" rm -rf /root/.cache/*") logger.info(" rm -rf /tmp/*") logger.info(" find /root/autodl-tmp -name 'checkpoint-*' -type d | head -10 | xargs rm -rf") return # 确保所有输出目录存在 os.makedirs(output_dir, exist_ok=True) os.makedirs(checkpoint_dir, exist_ok=True) logger.info(f"📁 确保输出目录存在:") logger.info(f" 🔹 对抗训练模型输出: {output_dir}") logger.info(f" 🔹 对抗训练记录: {checkpoint_dir}") logger.info(f"\n📋 FreeLB对抗训练配置 (存储优化版):") logger.info(f" 🔹 训练数据: {train_file}") logger.info(f" 🔹 基础模型: {model_path}") logger.info(f" 🔹 预训练模型: {base_model_path}") logger.info(f" 🔹 对抗训练方法: FreeLB (温和版本)") logger.info(f" 🔹 存储优化: 每10轮保存,仅保留1个checkpoint") logger.info(f" 🔹 硬件: V100 48GB GPU优化") logger.info(f"\n🎯 温和对抗训练特性:") logger.info(f" 🔹 温和扰动强度 (多数类0.05, 少数类0.1)") logger.info(f" 🔹 减少对抗步数 (3步)") logger.info(f" 🔹 延长适应期 (前10轮)") logger.info(f" 🔹 更多干净样本混合") logger.info(f" 🔹 存储空间最小化") # 加载和划分数据 train_data, val_data, imbalance_ratio = load_and_split_data(train_file, validation_split=0.2, random_state=42) if train_data is None or val_data is None: logger.error("❌ 无法加载和划分数据,程序退出") return logger.info(f"📊 数据不平衡分析: {imbalance_ratio:.1f}:1") try: # 执行完整的对抗训练 trainer, best_model, tokenizer, loss_tracker, adv_callback = train_adversarial_roberta_model( train_data, val_data, imbalance_ratio, base_model_path=base_model_path, model_path=model_path, output_dir=output_dir, checkpoint_dir=checkpoint_dir ) logger.info("=" * 120) logger.info("🎉 FreeLB温和对抗训练完成!") logger.info("=" * 120) logger.info(f"🏆 最佳对抗训练模型信息:") logger.info(f" 🔹 少数类F1分数: {loss_tracker.best_minority_f1:.4f}") logger.info(f" 🔹 整体准确率: {loss_tracker.best_eval_accuracy:.4f}") logger.info(f" 🔹 最佳模型来自: Epoch {loss_tracker.best_epoch}") logger.info(f" 🔹 选择标准: 整体准确率最高") # 计算改进效果 baseline_acc = 0.9451 # 原始模型准确率 if loss_tracker.best_eval_accuracy > 0: acc_improvement = ((loss_tracker.best_eval_accuracy - baseline_acc) / baseline_acc * 100) logger.info(f" 🔹 准确率改进: {acc_improvement:+.2f}%") logger.info(f"\n📁 文件输出位置:") logger.info(f" 🔹 最佳对抗训练模型: {output_dir}") logger.info(f" 🔹 训练记录: {checkpoint_dir}") logger.info("🔥 温和FreeLB对抗训练的核心优势:") logger.info(" ✅ 平稳的性能提升过程") logger.info(" ✅ 温和但有效的鲁棒性增强") logger.info(" ✅ 避免性能大幅下降") logger.info(" ✅ 存储空间最小化") logger.info(" ✅ 适合有限资源环境") logger.info(f"\n🎯 温和对抗训练完成,模型已优化并保存!") except Exception as e: logger.error(f"❌ 对抗训练过程中出现错误: {str(e)}") import traceback traceback.print_exc() raise if __name__ == "__main__": main()