|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import pandas as pd |
|
|
import numpy as np |
|
|
import matplotlib.pyplot as plt |
|
|
import seaborn as sns |
|
|
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix |
|
|
from sklearn.model_selection import train_test_split |
|
|
from transformers import ( |
|
|
BertTokenizer, |
|
|
BertForSequenceClassification, |
|
|
BertModel, |
|
|
BertConfig, |
|
|
TrainingArguments, |
|
|
Trainer, |
|
|
DataCollatorWithPadding, |
|
|
TrainerCallback, |
|
|
EarlyStoppingCallback |
|
|
) |
|
|
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler |
|
|
import logging |
|
|
import os |
|
|
from datetime import datetime |
|
|
import math |
|
|
import json |
|
|
from collections import defaultdict, Counter |
|
|
import copy |
|
|
|
|
|
# Turn off wandb / Comet / Neptune integrations so the Trainer never tries to
# report to an external experiment-tracking service.
os.environ.update({
    "WANDB_DISABLED": "true",
    "COMET_MODE": "disabled",
    "NEPTUNE_MODE": "disabled",
})

# Root logging at INFO level and the module-level logger used throughout.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Matplotlib fonts. NOTE: DejaVu Sans is a Latin font (the plot labels in this
# file are English); unicode_minus=False draws the minus sign as an ASCII hyphen.
plt.rcParams.update({
    'font.sans-serif': ['DejaVu Sans'],
    'axes.unicode_minus': False,
})
|
|
|
|
|
|
|
|
def check_gpu_availability():
    """Ensure a CUDA device is present and choose a GPU configuration profile.

    Logs the device count plus the name and total memory of device 0. It then
    clears the CUDA cache and enables cuDNN autotuning, with deterministic
    kernels disabled for speed.

    Returns:
        tuple: ``(True, memory_gb, profile)``. ``profile`` is ``"v100_48gb"``
        when device 0 reports at least 40 GB and ``"standard"`` otherwise.

    Raises:
        RuntimeError: when CUDA is not available.
    """
    if not torch.cuda.is_available():
        raise RuntimeError("❌ GPU不可用!此脚本需要GPU支持。")

    device_total = torch.cuda.device_count()
    primary_name = torch.cuda.get_device_name(0)
    primary_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1024 ** 3

    for message in (
        f"✅ GPU检查通过!",
        f" 🔹 可用GPU数量: {device_total}",
        f" 🔹 GPU型号: {primary_name}",
        f" 🔹 GPU内存: {primary_mem_gb:.1f} GB",
    ):
        logger.info(message)

    # Throughput-oriented cuDNN settings (non-deterministic kernels allowed).
    torch.cuda.empty_cache()
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False

    large_memory = primary_mem_gb >= 40
    if large_memory:
        logger.info("🚀 检测到V100大显存GPU,启用对抗训练优化配置")
        profile = "v100_48gb"
    else:
        logger.info("⚠️ 检测到较小显存GPU,将使用保守的对抗训练配置")
        profile = "standard"
    return True, primary_mem_gb, profile
|
|
|
|
|
|
|
|
class AdversarialConfig:
    """Hyper-parameters for a mild, phase-based adversarial training schedule.

    The config holds:
    - per-class perturbation radii (the minority class gets twice the majority radius);
    - focal-loss settings;
    - default FreeLB step parameters;
    - per-phase overrides for 'adaptation', 'enhancement' and 'refinement'.

    Constructing it logs a summary of the main settings.

    Args:
        imbalance_ratio: majority:minority ratio of the training data; stored
            and logged.
        gpu_type: GPU profile label (e.g. "v100_48gb" or "standard"); stored
            and logged.
    """

    def __init__(self, imbalance_ratio=6.0, gpu_type="v100_48gb"):
        self.imbalance_ratio = imbalance_ratio
        self.gpu_type = gpu_type

        # Mild perturbation radii; minority radius stays 2x the majority one.
        self.majority_class_eps, self.minority_class_eps = 0.05, 0.1
        self.eps_ratio = 2.0

        # Focal-loss settings kept at their original values.
        self.focal_alpha = [0.1, 0.9]
        self.focal_gamma = 3.0

        # Default FreeLB parameters (few steps, small step size).
        self.adv_steps = 3
        self.adv_lr = 0.01
        self.norm_type = "l2"
        self.max_norm = 0.5

        def make_phase(majority_eps, minority_eps, clean_ratio, learning_rate, adv_steps):
            # One phase entry: per-class eps, share of clean samples per batch,
            # learning rate and number of adversarial steps.
            return {
                'eps': {'majority': majority_eps, 'minority': minority_eps},
                'clean_ratio': clean_ratio,
                'learning_rate': learning_rate,
                'adv_steps': adv_steps,
            }

        # Phase overrides, looked up by phase name.
        self.phase_configs = {
            'adaptation': make_phase(0.03, 0.06, 0.9, 2e-5, 2),
            'enhancement': make_phase(0.05, 0.1, 0.7, 1.5e-5, 3),
            'refinement': make_phase(0.04, 0.08, 0.8, 1e-5, 2),
        }

        # Target class proportions per phase, gradually moving toward balance.
        self.sampling_configs = {
            phase_name: {'majority': majority_share, 'minority': minority_share}
            for phase_name, majority_share, minority_share in (
                ('adaptation', 0.8, 0.2),
                ('enhancement', 0.7, 0.3),
                ('refinement', 0.6, 0.4),
            )
        }

        summary = (
            f"🎯 温和对抗训练配置初始化完成:",
            f" 🔹 数据不平衡比例: {imbalance_ratio}:1",
            f" 🔹 Focal Loss权重: {self.focal_alpha} (原始设置)",
            f" 🔹 Focal Loss gamma: {self.focal_gamma} (原始设置)",
            f" 🔹 温和扰动: 多数类{self.majority_class_eps} vs 少数类{self.minority_class_eps}",
            f" 🔹 对抗步数: {self.adv_steps} (温和设置)",
            f" 🔹 对抗学习率: {self.adv_lr} (温和设置)",
            f" 🔹 GPU配置: {gpu_type}",
        )
        for line in summary:
            logger.info(line)
|
|
|
|
|
|
|
|
class AdversarialPerturbationGenerator:
    """FreeLB-style adversarial perturbation generator for word embeddings.

    The perturbation starts small and random. It is then refined by
    ``adv_steps`` steps of normalised gradient ascent on the model loss with
    respect to the perturbation. After every step it is projected back onto a
    per-token norm ball.

    The ball's radius depends on each sample's label: label 0 uses the phase's
    ``'majority'`` eps, and every other label uses the ``'minority'`` eps.
    """

    def __init__(self, config: "AdversarialConfig"):
        # String annotation so the class does not need AdversarialConfig at import time.
        self.config = config

    def generate_perturbation(self, model, inputs, labels, current_phase='enhancement'):
        """Return word embeddings of ``inputs['input_ids']`` plus an adversarial delta.

        Args:
            model: module exposing ``model.roberta.embeddings.word_embeddings``
                and accepting ``inputs_embeds``/``attention_mask``/``labels``
                keyword arguments; it must return a mapping with a ``'loss'``.
            inputs: mapping with ``'input_ids'`` and ``'attention_mask'`` tensors.
            labels: tensor with one class label per sample.
            current_phase: key into ``config.phase_configs`` (supplies ``eps``
                and ``adv_steps``).

        Returns:
            Tensor ``embeds + delta``. Gradients flow back into the embedding
            table through ``embeds``; ``delta`` is detached.
        """
        device = next(model.parameters()).device
        phase_config = self.config.phase_configs[current_phase]
        eps_config = phase_config['eps']

        input_ids = inputs['input_ids'].to(device)
        attention_mask = inputs['attention_mask'].to(device)
        labels = labels.to(device)

        embeds = model.roberta.embeddings.word_embeddings(input_ids)

        # Per-sample radius chosen on-device (avoids one host sync per sample).
        batch_eps = torch.where(
            labels == 0,
            torch.tensor(eps_config['majority'], device=device),
            torch.tensor(eps_config['minority'], device=device),
        ).reshape(-1, 1, 1)

        # Random start within 10% of each sample's radius.
        delta = torch.zeros_like(embeds).uniform_(-1, 1) * batch_eps * 0.1
        delta.requires_grad_(True)

        for _ in range(phase_config['adv_steps']):
            outputs = model(
                inputs_embeds=embeds + delta,
                attention_mask=attention_mask,
                labels=labels,
            )
            # Every step builds a fresh forward graph, so nothing needs to be
            # retained; only the gradient w.r.t. delta is computed (model
            # parameter .grad fields are untouched).
            (grad,) = torch.autograd.grad(outputs['loss'], delta)

            # Normalise the ascent direction.
            if self.config.norm_type == "l2":
                grad = grad / (torch.norm(grad, dim=-1, keepdim=True) + 1e-8)
            else:  # linf
                grad = grad.sign()

            delta = delta + self.config.adv_lr * grad

            # Project back onto the per-token eps ball.
            if self.config.norm_type == "l2":
                delta_norm = torch.norm(delta, dim=-1, keepdim=True)
                delta = delta / torch.max(delta_norm / batch_eps, torch.ones_like(delta_norm))
            else:
                delta = torch.clamp(delta, -batch_eps, batch_eps)

            delta = delta.detach().requires_grad_(True)

        return embeds + delta.detach()
|
|
|
|
|
|
|
|
class LossTracker(TrainerCallback):
    """Trainer callback that records loss/metric history for later plotting.

    It reads the ``'loss'``, ``'eval_loss'``, ``'eval_accuracy'`` and
    ``'eval_minority_f1'`` log entries, together with the global step of the
    training/eval losses. It also tracks the best accuracy, the best
    minority-class F1, and the epoch at which that best F1 was logged.
    """

    def __init__(self):
        self.train_losses = []
        self.eval_losses = []
        self.eval_accuracies = []
        self.eval_minority_f1 = []
        self.train_steps = []
        self.eval_steps = []
        self.epochs = []
        self.current_epoch = 0
        self.best_eval_accuracy = 0.0
        self.best_minority_f1 = 0.0
        self.best_epoch = 0

        # Adversarial-training specific series (not populated by this class).
        self.adversarial_losses = []
        self.clean_losses = []
        self.perturbation_norms = []

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Append any known metrics found in ``logs`` and update the best values."""
        if not logs:
            return

        if 'loss' in logs:
            self.train_losses.append(logs['loss'])
            self.train_steps.append(state.global_step)

        if 'eval_loss' in logs:
            self.eval_losses.append(logs['eval_loss'])
            self.eval_steps.append(state.global_step)

        if 'eval_accuracy' in logs:
            accuracy = logs['eval_accuracy']
            self.eval_accuracies.append(accuracy)
            if accuracy > self.best_eval_accuracy:
                self.best_eval_accuracy = accuracy

        if 'eval_minority_f1' in logs:
            minority_f1 = logs['eval_minority_f1']
            self.eval_minority_f1.append(minority_f1)
            if minority_f1 > self.best_minority_f1:
                self.best_minority_f1 = minority_f1
                # Use the epoch at log time so step-based evaluation is attributed
                # correctly (current_epoch only changes at epoch end).
                self.best_epoch = state.epoch if state.epoch is not None else self.current_epoch

    def on_epoch_end(self, args, state, control, **kwargs):
        """Record the (possibly fractional) epoch counter reported by the trainer."""
        self.current_epoch = state.epoch
        self.epochs.append(state.epoch)
|
|
|
|
|
|
|
|
class AdversarialValidationCallback(TrainerCallback):
    """Periodic, detailed validation analysis for adversarial training.

    The analysis runs at every ``epochs_interval``-th epoch and at the final
    epoch. Each time it:
    - runs the model over ``eval_dataset``;
    - builds a 2x2 confusion matrix (label 1 is treated as the minority class);
    - derives per-class precision/recall/F1 and overall/balanced accuracy;
    - stores the results keyed by epoch and saves a figure to ``output_dir``.
    """

    def __init__(self, eval_dataset, tokenizer, output_dir, adv_config, epochs_interval=5,
                 eval_batch_size=64):
        self.eval_dataset = eval_dataset
        self.tokenizer = tokenizer  # kept for API compatibility; not used here
        self.output_dir = output_dir
        self.adv_config = adv_config
        self.epochs_interval = epochs_interval
        self.eval_batch_size = eval_batch_size  # batch size for the analysis pass
        self.confusion_matrices = {}
        self.detailed_metrics = {}

    def on_epoch_end(self, args, state, control, model=None, **kwargs):
        """Run the analysis on scheduled epochs; otherwise do nothing."""
        current_epoch = int(state.epoch)
        if current_epoch % self.epochs_interval != 0 and current_epoch != args.num_train_epochs:
            return
        if model is None:
            logger.warning(f"Epoch {current_epoch}: 未提供模型,跳过对抗训练验证分析")
            return

        logger.info(f"📊 生成第{current_epoch}轮对抗训练验证分析...")

        true_labels, predictions = self._collect_predictions(model)
        if not true_labels:
            logger.warning(f"Epoch {current_epoch}: 验证集为空,跳过对抗训练验证分析")
            return

        # labels=[0, 1] guarantees a 2x2 matrix even when a class is absent.
        cm = confusion_matrix(true_labels, predictions, labels=[0, 1])
        self.confusion_matrices[current_epoch] = cm

        metrics = self._binary_metrics(cm)
        self.detailed_metrics[current_epoch] = metrics

        logger.info(f" 📈 Epoch {current_epoch} 详细指标:")
        logger.info(f" 总体准确率: {metrics['overall_accuracy']:.4f}")
        logger.info(f" 平衡准确率: {metrics['balanced_accuracy']:.4f}")
        logger.info(f" 少数类F1: {metrics['minority_f1']:.4f}")
        logger.info(f" 多数类F1: {metrics['majority_f1']:.4f}")

        self.save_adversarial_analysis(cm, metrics, current_epoch)

    def _collect_predictions(self, model):
        """Predict every eval item in mini-batches; return (true_labels, predictions) lists."""
        was_training = model.training
        model.eval()
        device = next(model.parameters()).device
        true_labels, predictions = [], []
        total = len(self.eval_dataset)
        try:
            with torch.no_grad():
                for start in range(0, total, self.eval_batch_size):
                    items = [self.eval_dataset[i]
                             for i in range(start, min(start + self.eval_batch_size, total))]
                    # pad_sequence tolerates unequal lengths; padded positions get
                    # attention_mask 0 so they do not affect the prediction.
                    input_ids = nn.utils.rnn.pad_sequence(
                        [item['input_ids'] for item in items], batch_first=True).to(device)
                    attention_mask = nn.utils.rnn.pad_sequence(
                        [item['attention_mask'] for item in items], batch_first=True).to(device)

                    logits = model(input_ids=input_ids, attention_mask=attention_mask)['logits']
                    predictions.extend(torch.argmax(logits, dim=-1).cpu().tolist())
                    true_labels.extend(int(item['labels']) for item in items)
        finally:
            # Restore whatever mode the trainer had the model in.
            model.train(was_training)
        return true_labels, predictions

    @staticmethod
    def _binary_metrics(cm):
        """Per-class and overall metrics from a 2x2 confusion matrix [[tn, fp], [fn, tp]]."""
        tn, fp, fn, tp = (int(v) for v in np.asarray(cm).ravel())

        def ratio(numerator, denominator):
            return numerator / denominator if denominator > 0 else 0

        def f1(precision, recall):
            return 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

        minority_precision = ratio(tp, tp + fp)
        minority_recall = ratio(tp, tp + fn)
        majority_precision = ratio(tn, tn + fn)
        majority_recall = ratio(tn, tn + fp)

        return {
            'overall_accuracy': ratio(tp + tn, tp + tn + fp + fn),
            'balanced_accuracy': (minority_recall + majority_recall) / 2,
            'minority_precision': minority_precision,
            'minority_recall': minority_recall,
            'minority_f1': f1(minority_precision, minority_recall),
            'majority_precision': majority_precision,
            'majority_recall': majority_recall,
            'majority_f1': f1(majority_precision, majority_recall),
        }

    def save_adversarial_analysis(self, cm, metrics, epoch):
        """Save a 2x2 figure for this epoch to ``output_dir``.

        Panels: confusion matrix, per-class scores, metric trend, and true vs
        predicted class distribution.
        """
        fig, axes = plt.subplots(2, 2, figsize=(15, 12))

        # Confusion matrix.
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=axes[0, 0],
                    xticklabels=['Same Paragraph (0)', 'Different Paragraph (1)'],
                    yticklabels=['Same Paragraph (0)', 'Different Paragraph (1)'])
        axes[0, 0].set_title(f'Adversarial Validation CM - Epoch {epoch}')
        axes[0, 0].set_xlabel('Predicted Label')
        axes[0, 0].set_ylabel('True Label')

        # Per-class precision / recall / F1.
        categories = ['Majority Class', 'Minority Class']
        precision_scores = [metrics['majority_precision'], metrics['minority_precision']]
        recall_scores = [metrics['majority_recall'], metrics['minority_recall']]
        f1_scores = [metrics['majority_f1'], metrics['minority_f1']]

        x = np.arange(len(categories))
        width = 0.25
        axes[0, 1].bar(x - width, precision_scores, width, label='Precision', alpha=0.8)
        axes[0, 1].bar(x, recall_scores, width, label='Recall', alpha=0.8)
        axes[0, 1].bar(x + width, f1_scores, width, label='F1-Score', alpha=0.8)
        axes[0, 1].set_xlabel('Class')
        axes[0, 1].set_ylabel('Score')
        axes[0, 1].set_title(f'Class Performance - Epoch {epoch}')
        axes[0, 1].set_xticks(x)
        axes[0, 1].set_xticklabels(categories)
        axes[0, 1].legend()
        axes[0, 1].set_ylim(0, 1)

        # Trend across analysed epochs (needs at least two points).
        if len(self.detailed_metrics) > 1:
            epochs_list = sorted(self.detailed_metrics.keys())
            overall_acc = [self.detailed_metrics[e]['overall_accuracy'] for e in epochs_list]
            balanced_acc = [self.detailed_metrics[e]['balanced_accuracy'] for e in epochs_list]
            minority_f1_trend = [self.detailed_metrics[e]['minority_f1'] for e in epochs_list]

            axes[1, 0].plot(epochs_list, overall_acc, 'b-o', label='Overall Accuracy', linewidth=2)
            axes[1, 0].plot(epochs_list, balanced_acc, 'g-o', label='Balanced Accuracy', linewidth=2)
            axes[1, 0].plot(epochs_list, minority_f1_trend, 'r-o', label='Minority F1', linewidth=2)
            axes[1, 0].set_xlabel('Epoch')
            axes[1, 0].set_ylabel('Score')
            axes[1, 0].set_title('Adversarial Training Progress')
            axes[1, 0].legend()
            axes[1, 0].grid(True, alpha=0.3)

        # True vs predicted class distribution.
        if cm.shape == (2, 2):
            class_counts = [cm[0, 0] + cm[0, 1], cm[1, 0] + cm[1, 1]]  # true counts per class
            predicted_counts = [cm[0, 0] + cm[1, 0], cm[0, 1] + cm[1, 1]]  # predicted counts per class

            x_pos = [0, 1]
            axes[1, 1].bar([p - 0.2 for p in x_pos], class_counts, 0.4, label='True Distribution', alpha=0.7)
            axes[1, 1].bar([p + 0.2 for p in x_pos], predicted_counts, 0.4, label='Predicted Distribution', alpha=0.7)
            axes[1, 1].set_xlabel('Class')
            axes[1, 1].set_ylabel('Count')
            axes[1, 1].set_title('Class Distribution Analysis')
            axes[1, 1].set_xticks(x_pos)
            axes[1, 1].set_xticklabels(['Majority (0)', 'Minority (1)'])
            axes[1, 1].legend()
        else:
            # Matrix passed in from elsewhere with an unexpected shape.
            axes[1, 1].text(0.5, 0.5, f'Confusion Matrix Shape: {cm.shape}',
                            ha='center', va='center', transform=axes[1, 1].transAxes)
            axes[1, 1].set_title('Class Distribution Analysis')

        textstr = "\n".join([
            f"Epoch {epoch} Performance:",
            f"Overall Acc: {metrics['overall_accuracy']:.4f}",
            f"Balanced Acc: {metrics['balanced_accuracy']:.4f}",
            f"Minority F1: {metrics['minority_f1']:.4f}",
            f"Minority Recall: {metrics['minority_recall']:.4f}",
        ])
        props = dict(boxstyle='round', facecolor='wheat', alpha=0.8)
        fig.text(0.02, 0.98, textstr, transform=fig.transFigure, fontsize=10,
                 verticalalignment='top', bbox=props)

        fig.tight_layout()

        os.makedirs(self.output_dir, exist_ok=True)
        save_path = os.path.join(self.output_dir, f'adversarial_analysis_epoch_{epoch}.png')
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close(fig)

        logger.info(f" 💾 对抗训练分析图表已保存: {save_path}")
|
|
|
|
|
|
|
|
class SentencePairDataset(Dataset): |
|
|
"""句子对数据集类 - 对抗训练优化版""" |
|
|
|
|
|
def __init__(self, data, tokenizer, max_length=512, is_validation=False, adversarial_config=None): |
|
|
self.data = data |
|
|
self.tokenizer = tokenizer |
|
|
self.max_length = max_length |
|
|
self.is_validation = is_validation |
|
|
self.adversarial_config = adversarial_config |
|
|
|
|
|
self.valid_data = [item for item in data if item['label'] in [0, 1]] |
|
|
dataset_type = "验证" if is_validation else "训练" |
|
|
logger.info(f"原始{dataset_type}数据: {len(data)} 条,有效数据: {len(self.valid_data)} 条") |
|
|
|
|
|
self.sentence1_list = [item['sentence1'] for item in self.valid_data] |
|
|
self.sentence2_list = [item['sentence2'] for item in self.valid_data] |
|
|
self.labels = [item['label'] for item in self.valid_data] |
|
|
|
|
|
# 对抗训练专用的类别分析 |
|
|
if not is_validation and adversarial_config: |
|
|
self.class_counts = Counter(self.labels) |
|
|
self.class_weights = self._compute_adversarial_class_weights() |
|
|
self.sample_weights = self._compute_adversarial_sample_weights() |
|
|
logger.info(f"对抗训练类别权重: {self.class_weights}") |
|
|
|
|
|
def _compute_adversarial_class_weights(self): |
|
|
"""计算对抗训练专用的类别权重 - 保持原始权重比例""" |
|
|
total_samples = len(self.labels) |
|
|
class_weights = {} |
|
|
|
|
|
# 使用与原始训练相同的权重计算方法 |
|
|
for label in [0, 1]: |
|
|
count = self.class_counts[label] |
|
|
class_weights[label] = total_samples / (2 * count) # 标准的平衡权重计算 |
|
|
|
|
|
return class_weights |
|
|
|
|
|
def _compute_adversarial_sample_weights(self): |
|
|
"""计算对抗训练专用的样本权重""" |
|
|
sample_weights = [] |
|
|
for label in self.labels: |
|
|
sample_weights.append(self.class_weights[label]) |
|
|
return torch.tensor(sample_weights, dtype=torch.float) |
|
|
|
|
|
def get_weighted_sampler(self): |
|
|
"""返回对抗训练优化的加权随机采样器""" |
|
|
if self.is_validation: |
|
|
raise ValueError("验证集不需要加权采样器") |
|
|
return WeightedRandomSampler( |
|
|
weights=self.sample_weights, |
|
|
num_samples=len(self.sample_weights), |
|
|
replacement=True |
|
|
) |
|
|
|
|
|
def __len__(self): |
|
|
return len(self.valid_data) |
|
|
|
|
|
def __getitem__(self, idx): |
|
|
sentence1 = str(self.sentence1_list[idx]) |
|
|
sentence2 = str(self.sentence2_list[idx]) |
|
|
label = self.labels[idx] |
|
|
|
|
|
encoding = self.tokenizer( |
|
|
sentence1, |
|
|
sentence2, |
|
|
truncation=True, |
|
|
padding='max_length', |
|
|
max_length=self.max_length, |
|
|
return_tensors='pt' |
|
|
) |
|
|
|
|
|
return { |
|
|
'input_ids': encoding['input_ids'].flatten(), |
|
|
'attention_mask': encoding['attention_mask'].flatten(), |
|
|
'labels': torch.tensor(label, dtype=torch.long) |
|
|
} |
|
|
|
|
|
|
|
|
def load_and_split_data(train_file, validation_split=0.2, random_state=42):
    """Load a JSON list of labelled sentence pairs and split it into train/validation.

    Only records whose ``'label'`` is 0 or 1 are kept. The split is stratified
    on the label.

    Args:
        train_file: path to a UTF-8 JSON file containing a list of records.
        validation_split: fraction of records placed in the validation set.
        random_state: seed passed to ``train_test_split``.

    Returns:
        ``(train_data, val_data, imbalance_ratio)``, where ``imbalance_ratio``
        is ``count(label 0) / count(label 1)`` in the training split.
        Returns ``(None, None, None)`` on any failure, with the error logged.
    """
    try:
        with open(train_file, 'r', encoding='utf-8') as f:
            all_data = json.load(f)
        logger.info(f"成功加载原始数据: {len(all_data)} 条记录")

        # .get() so records without a label are skipped rather than aborting the load.
        valid_data = [item for item in all_data if item.get('label') in (0, 1)]
        logger.info(f"有效数据: {len(valid_data)} 条记录")

        labels = [item['label'] for item in valid_data]
        train_data, val_data = train_test_split(
            valid_data,
            test_size=validation_split,
            random_state=random_state,
            stratify=labels
        )

        logger.info(f"对抗训练数据划分完成:")
        logger.info(f" 🔹 训练集: {len(train_data)} 条")
        logger.info(f" 🔹 验证集: {len(val_data)} 条")

        train_counts = Counter(item['label'] for item in train_data)
        val_counts = Counter(item['label'] for item in val_data)

        if train_counts[1] == 0:
            # Without minority samples the imbalance ratio is undefined.
            logger.error("加载和划分数据失败: 训练集中没有少数类(1)样本,无法计算不平衡比例")
            return None, None, None

        imbalance_ratio = train_counts[0] / train_counts[1]
        val_ratio = val_counts[0] / val_counts[1] if val_counts[1] else float('inf')

        logger.info(f"📊 数据不平衡分析:")
        logger.info(f" 训练集 - 多数类:{train_counts[0]} 少数类:{train_counts[1]} 比例:{imbalance_ratio:.1f}:1")
        logger.info(f" 验证集 - 多数类:{val_counts[0]} 少数类:{val_counts[1]} 比例:{val_ratio:.1f}:1")

        return train_data, val_data, imbalance_ratio

    except Exception as e:
        # Best-effort loader: keep the (None, None, None) contract but log the traceback.
        logger.exception(f"加载和划分数据失败: {str(e)}")
        return None, None, None
|
|
|
|
|
|
|
|
def _predictions_to_class_ids(predictions, num_samples):
    """Reduce raw model predictions to a 1-D integer array of predicted class ids."""
    if predictions.ndim == 3:
        # (batch, seq, classes): use the last position.
        predictions = predictions[:, -1, :]
    if predictions.ndim == 1 and predictions.size == num_samples:
        # Already one class id per sample (e.g. via preprocess_logits_for_metrics).
        return predictions.astype(np.int64)
    if predictions.ndim != 2:
        # Flattened binary logits.
        predictions = predictions.reshape(-1, 2)
    return np.argmax(predictions, axis=1)


def compute_adversarial_metrics(eval_pred):
    """Compute accuracy and minority-class (label 1) metrics for a Trainer evaluation.

    Args:
        eval_pred: ``(predictions, labels)``. ``predictions`` may be a
            tuple/list whose first entry holds the logits.

    Returns:
        dict with ``accuracy``, ``balanced_accuracy``, ``minority_precision``,
        ``minority_recall`` and ``minority_f1``.
    """
    predictions, labels = eval_pred
    if isinstance(predictions, (list, tuple)):
        predictions = predictions[0]

    labels = np.asarray(labels).reshape(-1)
    predictions = _predictions_to_class_ids(np.asarray(predictions), len(labels))

    if len(predictions) != len(labels):
        # A mismatch means predictions are misaligned with labels; make it visible.
        logger.warning(f"预测数量({len(predictions)})与标签数量({len(labels)})不一致,按较短长度截断")
        min_length = min(len(predictions), len(labels))
        predictions = predictions[:min_length]
        labels = labels[:min_length]

    logger.info(f"评估样本数量: predictions={len(predictions)}, labels={len(labels)}")

    accuracy = accuracy_score(labels, predictions)

    try:
        # labels=[0, 1] always yields a 2x2 matrix, even if a class is absent.
        tn, fp, fn, tp = (int(v) for v in confusion_matrix(labels, predictions, labels=[0, 1]).ravel())
    except ValueError as e:
        logger.error(f"混淆矩阵计算错误: {e}")
        tn, fp, fn, tp = 0, 0, 0, 0

    minority_precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    minority_recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    minority_f1 = (2 * minority_precision * minority_recall / (minority_precision + minority_recall)
                   if (minority_precision + minority_recall) > 0 else 0)

    majority_recall = tn / (tn + fp) if (tn + fp) > 0 else 0
    balanced_accuracy = (minority_recall + majority_recall) / 2

    return {
        'accuracy': accuracy,
        'balanced_accuracy': balanced_accuracy,
        'minority_precision': minority_precision,
        'minority_recall': minority_recall,
        'minority_f1': minority_f1,
    }
|
|
|
|
|
|
|
|
class FocalLoss(nn.Module):
    """Focal loss for multi-class classification.

    ``loss_i = alpha[y_i] * (1 - p_t) ** gamma * CE_i``, with ``p_t = exp(-CE_i)``.

    Args:
        alpha: optional per-class weights, as a tensor or any sequence of numbers.
        gamma: focusing exponent (default 3.0).
        reduction: 'mean', 'sum', or anything else for per-sample losses.
    """

    def __init__(self, alpha=None, gamma=3.0, reduction='mean'):
        super(FocalLoss, self).__init__()
        # Accept plain lists/tuples (e.g. AdversarialConfig.focal_alpha) as well as tensors.
        if alpha is not None and not isinstance(alpha, torch.Tensor):
            alpha = torch.as_tensor(alpha, dtype=torch.float)
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """``inputs``: logits (N, C); ``targets``: class indices (N,)."""
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        # p_t is taken from the unweighted CE so alpha does not distort the focusing term.
        pt = torch.exp(-ce_loss)

        if self.alpha is not None:
            if self.alpha.device != inputs.device or self.alpha.dtype != inputs.dtype:
                # Cache the converted weights so later calls skip the copy.
                self.alpha = self.alpha.to(device=inputs.device, dtype=inputs.dtype)
            ce_loss = ce_loss * self.alpha.gather(0, targets.view(-1))

        focal_loss = (1 - pt) ** self.gamma * ce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss
|
|
|
|
|
|
|
|
class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d)) V, with dropout on the weights.

    ``d`` is the size of the last dimension of ``query``. Any leading
    dimensions, e.g. (batch, seq, d) or (batch, heads, seq, d), are supported.
    """

    def __init__(self, d_model, dropout=0.1):
        super(ScaledDotProductAttention, self).__init__()
        self.d_model = d_model
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        """Return ``(output, attention_weights)``.

        ``mask`` must broadcast to the score shape; positions where it is 0
        receive (numerically) zero attention weight.
        """
        d_k = query.size(-1)
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

        if mask is not None:
            # Lowest finite value of the dtype keeps fp16/bf16 safe (no -inf NaNs).
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        attention_weights = self.dropout(F.softmax(scores, dim=-1))
        output = torch.matmul(attention_weights, value)
        return output, attention_weights
|
|
|
|
|
|
|
|
class RoBERTaWithAdversarialTraining(nn.Module):
    """BERT-architecture encoder with extra self-attention, a linear head and focal loss.

    The forward pass runs a ``BertModel`` over the input and applies an extra
    scaled dot-product self-attention over its last hidden states. It then
    classifies the first ([CLS]) position with dropout + linear.

    With ``adversarial_training=True`` (in training mode, with labels), the
    word embeddings are perturbed by ``AdversarialPerturbationGenerator``
    before the standard forward pass.
    """

    def __init__(self, model_path, num_labels=2, dropout=0.1,
                 focal_alpha=None, focal_gamma=3.0, adversarial_config=None):
        super(RoBERTaWithAdversarialTraining, self).__init__()

        self.roberta = BertModel.from_pretrained(model_path)
        self.config = self.roberta.config
        self.config.num_labels = num_labels
        self.adversarial_config = adversarial_config

        self.scaled_attention = ScaledDotProductAttention(
            d_model=self.config.hidden_size,
            dropout=dropout
        )

        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(self.config.hidden_size, num_labels)

        self.focal_loss = FocalLoss(alpha=focal_alpha, gamma=focal_gamma)

        # Adversarial component (used only by _adversarial_forward).
        self.perturbation_generator = AdversarialPerturbationGenerator(adversarial_config)

        self._init_weights()

        self.focal_alpha = focal_alpha
        self.focal_gamma = focal_gamma
        self.current_phase = 'adaptation'

    def _init_weights(self):
        """Initialise the new classification head (normal std 0.02, zero bias)."""
        nn.init.normal_(self.classifier.weight, std=0.02)
        nn.init.zeros_(self.classifier.bias)

    def set_training_phase(self, phase):
        """Set the phase name used to look up perturbation settings."""
        if self.current_phase != phase:
            self.current_phase = phase

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None,
                labels=None, inputs_embeds=None, adversarial_training=False):
        """Return a dict with 'loss' (None without labels), 'logits', 'hidden_states', 'attention_weights'."""
        if adversarial_training and labels is not None and self.training:
            return self._adversarial_forward(input_ids, attention_mask, labels)
        return self._standard_forward(input_ids, attention_mask, token_type_ids, labels, inputs_embeds)

    def _standard_forward(self, input_ids, attention_mask, token_type_ids, labels, inputs_embeds):
        """Encoder -> extra self-attention -> [CLS] -> dropout -> linear -> focal loss."""
        roberta_outputs = self.roberta(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            return_dict=True
        )

        sequence_output = roberta_outputs.last_hidden_state

        enhanced_output, attention_weights = self.scaled_attention(
            query=sequence_output,
            key=sequence_output,
            value=sequence_output,
            # (batch, 1, seq) mask broadcasts over the query dimension.
            mask=attention_mask.unsqueeze(1) if attention_mask is not None else None
        )

        cls_output = self.dropout(enhanced_output[:, 0, :])
        logits = self.classifier(cls_output)

        loss = self.focal_loss(logits, labels) if labels is not None else None

        return {
            'loss': loss,
            'logits': logits,
            'hidden_states': enhanced_output,
            'attention_weights': attention_weights
        }

    def _adversarial_forward(self, input_ids, attention_mask, labels):
        """Perturb the word embeddings, then run the standard forward pass on them."""
        inputs = {'input_ids': input_ids, 'attention_mask': attention_mask}

        perturbed_embeds = self.perturbation_generator.generate_perturbation(
            self, inputs, labels, self.current_phase
        )

        return self._standard_forward(
            input_ids=None,
            attention_mask=attention_mask,
            token_type_ids=None,
            labels=labels,
            inputs_embeds=perturbed_embeds
        )

    def save_pretrained(self, save_directory):
        """Save ``pytorch_model.bin`` (state dict) and a descriptive ``config.json``.

        ``focal_alpha`` may be a tensor or a plain sequence. ``adversarial_config``
        may be None, in which case it is written as null.
        """
        os.makedirs(save_directory, exist_ok=True)

        model_to_save = self.module if hasattr(self, 'module') else self
        torch.save(model_to_save.state_dict(), os.path.join(save_directory, 'pytorch_model.bin'))

        alpha = self.focal_alpha
        if alpha is not None:
            alpha = alpha.tolist() if hasattr(alpha, 'tolist') else list(alpha)

        adv = self.adversarial_config
        adv_summary = None if adv is None else {
            'imbalance_ratio': adv.imbalance_ratio,
            'eps_ratio': adv.eps_ratio,
            'adv_steps': adv.adv_steps,
            'norm_type': adv.norm_type
        }

        config_dict = {
            'model_type': 'RoBERTaWithAdversarialTraining',
            'base_model': 'chinese-roberta-wwm-ext',
            'num_labels': self.config.num_labels,
            'hidden_size': self.config.hidden_size,
            'focal_alpha': alpha,
            'focal_gamma': self.focal_gamma,
            'has_scaled_attention': True,
            'has_focal_loss': True,
            'has_adversarial_training': True,
            'adversarial_config': adv_summary,
            'optimization_level': 'adversarial_training_with_best_validation'
        }

        with open(os.path.join(save_directory, 'config.json'), 'w', encoding='utf-8') as f:
            json.dump(config_dict, f, ensure_ascii=False, indent=2)
|
|
|
|
|
|
|
|
class AdversarialTrainer(Trainer): |
|
|
"""自定义对抗训练器""" |
|
|
|
|
|
def __init__(self, adversarial_config, weighted_sampler=None, *args, **kwargs): |
|
|
super().__init__(*args, **kwargs) |
|
|
self.adversarial_config = adversarial_config |
|
|
self.weighted_sampler = weighted_sampler |
|
|
self.current_epoch = 0 |
|
|
self.current_phase = 'adaptation' # 初始化当前阶段 |
|
|
self.phase_transitions = {8: 'enhancement', 25: 'refinement'} |
|
|
self._phase_logged = False # 避免重复日志 |
|
|
|
|
|
def get_train_dataloader(self): |
|
|
if self.train_dataset is None: |
|
|
raise ValueError("Trainer: training requires a train_dataset.") |
|
|
|
|
|
train_dataset = self.train_dataset |
|
|
|
|
|
if self.weighted_sampler is not None: |
|
|
train_sampler = self.weighted_sampler |
|
|
else: |
|
|
train_sampler = self._get_train_sampler() |
|
|
|
|
|
return DataLoader( |
|
|
train_dataset, |
|
|
batch_size=self.args.train_batch_size, |
|
|
sampler=train_sampler, |
|
|
collate_fn=self.data_collator, |
|
|
drop_last=self.args.dataloader_drop_last, |
|
|
num_workers=self.args.dataloader_num_workers, |
|
|
pin_memory=self.args.dataloader_pin_memory, |
|
|
) |
|
|
|
|
|
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): |
|
|
"""自定义损失计算 - 混合对抗训练""" |
|
|
|
|
|
# 确定当前训练阶段(避免频繁调用) |
|
|
current_phase = self._get_current_phase() |
|
|
if current_phase != self.current_phase: |
|
|
self.current_phase = current_phase |
|
|
model.set_training_phase(current_phase) |
|
|
logger.info(f"🎯 切换到训练阶段: {current_phase}") |
|
|
elif not self._phase_logged: |
|
|
model.set_training_phase(current_phase) |
|
|
self._phase_logged = True |
|
|
|
|
|
# 获取阶段配置 |
|
|
phase_config = self.adversarial_config.phase_configs[current_phase] |
|
|
clean_ratio = phase_config['clean_ratio'] |
|
|
|
|
|
labels = inputs.get("labels") |
|
|
|
|
|
# 混合训练策略 |
|
|
if self.state.global_step % 1 == 0: # 每步都进行混合训练 |
|
|
batch_size = labels.size(0) |
|
|
clean_size = int(batch_size * clean_ratio) |
|
|
|
|
|
if clean_size > 0: |
|
|
# 干净样本训练 |
|
|
clean_inputs = {k: v[:clean_size] for k, v in inputs.items()} |
|
|
clean_outputs = model(**clean_inputs, adversarial_training=False) |
|
|
clean_loss = clean_outputs['loss'] |
|
|
else: |
|
|
clean_loss = 0 |
|
|
|
|
|
if clean_size < batch_size: |
|
|
# 对抗样本训练 |
|
|
adv_inputs = {k: v[clean_size:] for k, v in inputs.items()} |
|
|
adv_outputs = model(**adv_inputs, adversarial_training=True) |
|
|
adv_loss = adv_outputs['loss'] |
|
|
else: |
|
|
adv_loss = 0 |
|
|
|
|
|
# 组合损失 |
|
|
if isinstance(clean_loss, torch.Tensor) and isinstance(adv_loss, torch.Tensor): |
|
|
total_loss = clean_ratio * clean_loss + (1 - clean_ratio) * adv_loss |
|
|
outputs = clean_outputs if clean_size > 0 else adv_outputs |
|
|
elif isinstance(clean_loss, torch.Tensor): |
|
|
total_loss = clean_loss |
|
|
outputs = clean_outputs |
|
|
else: |
|
|
total_loss = adv_loss |
|
|
outputs = adv_outputs |
|
|
|
|
|
outputs['loss'] = total_loss |
|
|
else: |
|
|
# 标准训练 |
|
|
outputs = model(**inputs, adversarial_training=False) |
|
|
total_loss = outputs['loss'] |
|
|
|
|
|
return (total_loss, outputs) if return_outputs else total_loss |
|
|
|
|
|
def _get_current_phase(self): |
|
|
"""根据当前epoch确定训练阶段 - 温和对抗训练版本""" |
|
|
current_epoch = int(self.state.epoch) if self.state.epoch else 0 |
|
|
|
|
|
if current_epoch < 10: # 延长适应期 |
|
|
return 'adaptation' |
|
|
elif current_epoch < 25: |
|
|
return 'enhancement' |
|
|
else: |
|
|
return 'refinement' |
|
|
|
|
|
def on_epoch_begin(self, args, state, control, **kwargs):
    """Switch the training phase at fixed epoch boundaries.

    At the start of epoch 10 the phase becomes 'enhancement'; at epoch 25 it
    becomes 'refinement'. These are the same boundaries ``_get_current_phase``
    uses. On a switch, the method:

    * updates ``self.current_phase``;
    * clears ``self._phase_logged`` so the next ``compute_loss`` call pushes
      the phase to the model;
    * sets every optimizer param group's ``lr`` to the phase's
      ``learning_rate`` (skipped with a warning if no optimizer exists yet).

    NOTE(review): with ``lr_scheduler_type="cosine"``, the Trainer's LambdaLR
    scheduler rewrites each param group's ``lr`` from its own base LR on
    every ``step()``. This manual override therefore presumably lasts only
    until the next scheduler step. Confirm whether the phase LR is meant to
    take effect.
    """
    # The parent class may not implement this hook (transformers.Trainer does
    # not; TrainerCallback does), so only forward the call when it exists.
    parent_hook = getattr(super(), 'on_epoch_begin', None)
    if parent_hook is not None:
        parent_hook(args, state, control, **kwargs)

    # Treat a missing epoch as 0, the same way _get_current_phase does.
    current_epoch = int(state.epoch) if state.epoch else 0

    # Epoch at which a phase starts -> phase name (gentle adversarial schedule:
    # a 10-epoch adaptation period first).
    phase_transitions = {10: 'enhancement', 25: 'refinement'}
    new_phase = phase_transitions.get(current_epoch)
    if new_phase is None:
        return

    self.current_phase = new_phase
    self._phase_logged = False  # reset so compute_loss re-applies the phase to the model
    logger.info(f"🔄 Epoch {current_epoch}: 切换到训练阶段 '{new_phase}' (温和对抗训练)")

    new_lr = self.adversarial_config.phase_configs[new_phase]['learning_rate']

    optimizer = getattr(self, 'optimizer', None)
    if optimizer is None:
        logger.warning("⚠️ 优化器尚未创建,跳过学习率调整")
        return

    for param_group in optimizer.param_groups:
        param_group['lr'] = new_lr

    logger.info(f" 📉 学习率调整为: {new_lr}")
|
|
|
|
|
|
|
|
def plot_adversarial_training_curves(loss_tracker, output_dir, baseline_acc=0.9451):
    """Render and save the adversarial-training diagnostic figures.

    Writes two PNG files (300 dpi) into ``output_dir``:

    * ``comprehensive_adversarial_training_curves.png``: a 3x3 grid with
      training loss, validation loss/accuracy, minority F1, moving-average
      loss, a metric summary, a phase timeline, a baseline comparison and a
      histogram of logging intervals.
    * ``key_adversarial_metrics.png``: a 2x2 grid with accuracy vs
      minority F1, smoothed training loss, first-to-last relative change,
      and a stability rating of the last 10 validation accuracies.

    Args:
        loss_tracker: object exposing the series ``train_steps``,
            ``train_losses``, ``eval_steps``, ``eval_losses``,
            ``eval_accuracies``, ``eval_minority_f1`` and ``epochs``, plus the
            scalar ``best_minority_f1``. Paired series of unequal length are
            truncated to their common length before plotting.
        output_dir: directory the PNG files are written to.
        baseline_acc: accuracy of the non-adversarial reference model, shown
            in the "Adversarial Training Impact" panel. Defaults to 0.9451,
            the previously hard-coded value.
    """
    _plot_comprehensive_adversarial_curves(loss_tracker, output_dir, baseline_acc)
    _plot_key_adversarial_metrics(loss_tracker, output_dir)


def _align_series(xs, ys):
    """Truncate two paired sequences to their common length so they can be plotted together."""
    common = min(len(xs), len(ys))
    return list(xs[:common]), list(ys[:common])


def _training_phase_color(epoch):
    """Timeline color for an epoch, using the trainer's phase boundaries (10 and 25)."""
    if epoch < 10:
        return 'lightgreen'  # adaptation
    if epoch < 25:
        return 'orange'  # enhancement
    return 'red'  # refinement


def _smooth_series(values, sigma=2.0):
    """Gaussian-smooth a 1-D series.

    Uses ``scipy.ndimage.gaussian_filter1d`` when SciPy is installed.
    Otherwise it uses an equivalent NumPy implementation with the same kernel
    radius ``int(4 * sigma + 0.5)`` and half-sample symmetric edge padding
    (SciPy's default ``reflect`` mode). A missing optional dependency
    therefore cannot abort plotting.
    """
    values = np.asarray(values, dtype=float)
    try:
        from scipy import ndimage
    except ImportError:
        radius = int(4 * sigma + 0.5)
        offsets = np.arange(-radius, radius + 1)
        kernel = np.exp(-0.5 * (offsets / sigma) ** 2)
        kernel /= kernel.sum()
        padded = np.pad(values, radius, mode='symmetric')
        return np.convolve(padded, kernel, mode='valid')
    return ndimage.gaussian_filter1d(values, sigma=sigma)


def _plot_comprehensive_adversarial_curves(loss_tracker, output_dir, baseline_acc):
    """Draw the 3x3 overview grid and save it as comprehensive_adversarial_training_curves.png."""
    from matplotlib.patches import Patch

    curves_path = os.path.join(output_dir, 'comprehensive_adversarial_training_curves.png')
    fig = plt.figure(figsize=(20, 15))
    try:
        # 1. Training loss trend
        if loss_tracker.train_losses:
            plt.subplot(3, 3, 1)
            train_steps, train_losses = _align_series(loss_tracker.train_steps, loss_tracker.train_losses)
            plt.plot(train_steps, train_losses,
                     'b-', label='Training Loss', linewidth=2, alpha=0.8)
            plt.title('Adversarial Training Loss', fontsize=14, fontweight='bold')
            plt.xlabel('Training Steps')
            plt.ylabel('Loss Value')
            plt.legend()
            plt.grid(True, alpha=0.3)

        # 2. Validation loss (left axis) vs validation accuracy (right axis)
        if loss_tracker.eval_losses and loss_tracker.eval_accuracies:
            plt.subplot(3, 3, 2)
            ax1 = plt.gca()
            ax2 = ax1.twinx()
            loss_steps, eval_losses = _align_series(loss_tracker.eval_steps, loss_tracker.eval_losses)
            acc_steps, eval_accs = _align_series(loss_tracker.eval_steps, loss_tracker.eval_accuracies)
            line1 = ax1.plot(loss_steps, eval_losses, 'r-', label='Validation Loss', linewidth=2)
            line2 = ax2.plot(acc_steps, eval_accs, 'g-', label='Validation Accuracy', linewidth=2)
            ax1.set_xlabel('Training Steps')
            ax1.set_ylabel('Loss', color='r')
            ax2.set_ylabel('Accuracy', color='g')
            ax1.set_title('Validation Metrics')
            lines = line1 + line2
            ax1.legend(lines, [line.get_label() for line in lines], loc='center right')
            ax1.grid(True, alpha=0.3)

        # 3. Minority-class F1 trend, with the best value marked
        if loss_tracker.eval_minority_f1:
            plt.subplot(3, 3, 3)
            f1_steps, f1_values = _align_series(loss_tracker.eval_steps, loss_tracker.eval_minority_f1)
            plt.plot(f1_steps, f1_values,
                     'purple', label='Minority F1', linewidth=2, marker='o', markersize=4)
            plt.title('Minority Class F1 Score Progress', fontsize=14, fontweight='bold')
            plt.xlabel('Training Steps')
            plt.ylabel('F1 Score')
            plt.legend()
            plt.grid(True, alpha=0.3)
            if loss_tracker.best_minority_f1 > 0:
                plt.axhline(y=loss_tracker.best_minority_f1, color='red', linestyle='--', alpha=0.7)
                plt.text(0.02, 0.98, f'Best: {loss_tracker.best_minority_f1:.4f}',
                         transform=plt.gca().transAxes, fontsize=10, verticalalignment='top',
                         bbox=dict(boxstyle="round,pad=0.3", facecolor="yellow", alpha=0.7))

        # 4. Validation accuracy trend
        if loss_tracker.eval_accuracies:
            plt.subplot(3, 3, 4)
            acc_steps, eval_accs = _align_series(loss_tracker.eval_steps, loss_tracker.eval_accuracies)
            plt.plot(acc_steps, eval_accs,
                     'g-o', label='Validation Accuracy', linewidth=2, markersize=3)
            plt.title('Validation Accuracy Trend', fontsize=14, fontweight='bold')
            plt.xlabel('Training Steps')
            plt.ylabel('Accuracy')
            plt.legend()
            plt.grid(True, alpha=0.3)

        # 5. Loss convergence: trailing moving average of the training loss
        window_size = 20
        ma_steps, ma_losses = _align_series(loss_tracker.train_steps, loss_tracker.train_losses)
        if len(ma_losses) > 50:
            plt.subplot(3, 3, 5)
            # Entry j averages losses[j : j + window_size] and belongs to the
            # window's last step, hence steps[window_size - 1:]. The window
            # includes the current point instead of lagging it by one.
            train_ma = np.convolve(ma_losses, np.ones(window_size) / window_size, mode='valid')
            plt.plot(ma_steps[window_size - 1:], train_ma,
                     'b-', label=f'Training Loss MA({window_size})', linewidth=2)
            plt.title('Loss Convergence Analysis', fontsize=14, fontweight='bold')
            plt.xlabel('Training Steps')
            plt.ylabel('Moving Average Loss')
            plt.legend()
            plt.grid(True, alpha=0.3)

        # 6. Performance summary bars
        plt.subplot(3, 3, 6)
        if loss_tracker.eval_accuracies and loss_tracker.eval_minority_f1:
            metrics_names = ['Overall Accuracy', 'Best Overall Acc', 'Minority F1', 'Best Minority F1']
            metrics_values = [
                loss_tracker.eval_accuracies[-1],
                max(loss_tracker.eval_accuracies),
                loss_tracker.eval_minority_f1[-1],
                loss_tracker.best_minority_f1,
            ]
            bars = plt.bar(metrics_names, metrics_values,
                           color=['skyblue', 'lightblue', 'lightcoral', 'coral'])
            plt.title('Performance Summary', fontsize=14, fontweight='bold')
            plt.ylabel('Score')
            plt.xticks(rotation=45)
            for bar, value in zip(bars, metrics_values):
                plt.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.01,
                         f'{value:.3f}', ha='center', va='bottom')

        # 7. Training-phase timeline (colors follow the trainer's 10/25 boundaries)
        plt.subplot(3, 3, 7)
        if loss_tracker.epochs:
            phase_colors = [_training_phase_color(epoch) for epoch in loss_tracker.epochs]
            plt.scatter(loss_tracker.epochs, [1] * len(loss_tracker.epochs),
                        c=phase_colors, s=50, alpha=0.7)
            plt.title('Training Phases Timeline', fontsize=14, fontweight='bold')
            plt.xlabel('Epoch')
            plt.ylabel('Phase')
            plt.yticks([1], ['Phases'])
            legend_elements = [Patch(facecolor='lightgreen', label='Adaptation (<10)'),
                               Patch(facecolor='orange', label='Enhancement (10-24)'),
                               Patch(facecolor='red', label='Refinement (25+)')]
            plt.legend(handles=legend_elements, loc='upper right')

        # 8. Latest accuracy vs the non-adversarial baseline
        plt.subplot(3, 3, 8)
        current_acc = loss_tracker.eval_accuracies[-1] if loss_tracker.eval_accuracies else 0
        improvement = ((current_acc - baseline_acc) / baseline_acc * 100) if baseline_acc > 0 else 0
        accuracies = [baseline_acc, current_acc]
        colors = ['lightgray', 'lightgreen' if improvement > 0 else 'lightcoral']
        plt.bar(['Baseline\n(No Adversarial)', 'Current\n(Adversarial)'], accuracies, color=colors)
        plt.title('Adversarial Training Impact', fontsize=14, fontweight='bold')
        plt.ylabel('Accuracy')
        # Keep the zoomed-in 0.9-1.0 window, but extend it downwards when a bar would be clipped.
        plt.ylim(min(0.9, min(accuracies) - 0.01), 1.0)
        plt.text(0.5, 0.95, f'Improvement: {improvement:+.2f}%',
                 transform=plt.gca().transAxes, ha='center',
                 bbox=dict(boxstyle="round,pad=0.3",
                           facecolor="lightgreen" if improvement > 0 else "lightcoral",
                           alpha=0.7))

        # 9. Histogram of step gaps between consecutive logged training losses
        plt.subplot(3, 3, 9)
        step_intervals = np.diff(loss_tracker.train_steps)
        if step_intervals.size:
            plt.hist(step_intervals, bins=20, alpha=0.7, color='skyblue', edgecolor='black')
            plt.title('Training Step Intervals', fontsize=14, fontweight='bold')
            plt.xlabel('Steps Between Logs')
            plt.ylabel('Frequency')
            plt.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig(curves_path, dpi=300, bbox_inches='tight')
    finally:
        # Always release the figure, even if a panel fails to render.
        plt.close(fig)

    logger.info(f"📈 对抗训练综合曲线已保存: {curves_path}")


def _plot_key_adversarial_metrics(loss_tracker, output_dir):
    """Draw the 2x2 key-metrics grid and save it as key_adversarial_metrics.png."""
    key_metrics_path = os.path.join(output_dir, 'key_adversarial_metrics.png')
    fig = plt.figure(figsize=(15, 10))
    try:
        # Overall accuracy vs minority-class F1 on their shared evaluation steps
        plt.subplot(2, 2, 1)
        if loss_tracker.eval_accuracies and loss_tracker.eval_minority_f1:
            min_len = min(len(loss_tracker.eval_steps),
                          len(loss_tracker.eval_accuracies),
                          len(loss_tracker.eval_minority_f1))
            steps = loss_tracker.eval_steps[:min_len]
            plt.plot(steps, loss_tracker.eval_accuracies[:min_len],
                     'b-o', label='Overall Accuracy', linewidth=2, markersize=4)
            plt.plot(steps, loss_tracker.eval_minority_f1[:min_len],
                     'r-o', label='Minority F1', linewidth=2, markersize=4)
            plt.title('Key Metrics Comparison', fontsize=16, fontweight='bold')
            plt.xlabel('Training Steps')
            plt.ylabel('Score')
            plt.legend()
            plt.grid(True, alpha=0.3)

        # Gaussian-smoothed training loss
        plt.subplot(2, 2, 2)
        smooth_steps, smooth_losses = _align_series(loss_tracker.train_steps, loss_tracker.train_losses)
        if len(smooth_losses) > 10:
            plt.plot(smooth_steps, _smooth_series(smooth_losses, sigma=2), 'g-', linewidth=3, alpha=0.8)
            plt.title('Smoothed Training Loss', fontsize=16, fontweight='bold')
            plt.xlabel('Training Steps')
            plt.ylabel('Loss')
            plt.grid(True, alpha=0.3)

        # Relative change (%) between the first and the last evaluation
        plt.subplot(2, 2, 3)
        if loss_tracker.eval_accuracies and loss_tracker.eval_minority_f1:
            initial_acc, final_acc = loss_tracker.eval_accuracies[0], loss_tracker.eval_accuracies[-1]
            initial_f1, final_f1 = loss_tracker.eval_minority_f1[0], loss_tracker.eval_minority_f1[-1]
            improvement_names = ['Overall Accuracy', 'Minority F1']
            improvement_values = [
                ((final_acc - initial_acc) / initial_acc * 100) if initial_acc > 0 else 0,
                ((final_f1 - initial_f1) / initial_f1 * 100) if initial_f1 > 0 else 0,
            ]
            colors = ['green' if imp > 0 else 'red' for imp in improvement_values]
            # Explicit lists (not dict views) are passed to bar().
            bars = plt.bar(improvement_names, improvement_values, color=colors, alpha=0.7)
            plt.title('Performance Improvements (%)', fontsize=16, fontweight='bold')
            plt.ylabel('Improvement (%)')
            plt.axhline(y=0, color='black', linestyle='-', alpha=0.3)
            for bar, imp in zip(bars, improvement_values):
                plt.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.5,
                         f'{imp:+.1f}%', ha='center', va='bottom')

        # Stability of the most recent validation accuracies (text-only panel)
        plt.subplot(2, 2, 4)
        if len(loss_tracker.eval_accuracies) > 5:
            acc_std = np.std(loss_tracker.eval_accuracies[-10:])  # std of the last 10 evaluations
            plt.text(0.5, 0.7, f'Recent Stability\n(Last 10 evals)',
                     ha='center', va='center', transform=plt.gca().transAxes,
                     fontsize=14, fontweight='bold')
            plt.text(0.5, 0.5, f'Std Dev: {acc_std:.4f}',
                     ha='center', va='center', transform=plt.gca().transAxes,
                     fontsize=12)
            stability_rating = "High" if acc_std < 0.01 else "Medium" if acc_std < 0.02 else "Low"
            color = "green" if stability_rating == "High" else "orange" if stability_rating == "Medium" else "red"
            plt.text(0.5, 0.3, f'Rating: {stability_rating}',
                     ha='center', va='center', transform=plt.gca().transAxes,
                     fontsize=12, color=color, fontweight='bold')
        plt.axis('off')

        plt.tight_layout()
        plt.savefig(key_metrics_path, dpi=300, bbox_inches='tight')
    finally:
        plt.close(fig)

    logger.info(f"📈 关键对抗训练指标图表已保存: {key_metrics_path}")
|
|
|
|
|
|
|
|
def train_adversarial_roberta_model(train_data, val_data, imbalance_ratio,
                                    base_model_path,
                                    model_path="/root/autodl-tmp/chinese-roberta-wwm-ext/chinese-roberta-wwm-ext/model",
                                    output_dir="/root/autodl-tmp/chinese-roberta-wwm-ext/chinese-roberta-wwm-ext/model_adversarial",
                                    checkpoint_dir="/root/autodl-tmp/chinese-roberta-wwm-ext/chinese-roberta-wwm-ext/output_adversarial"):
    """Run the complete FreeLB adversarial fine-tuning pipeline.

    Steps:

    1. Check the GPU.
    2. Build an ``AdversarialConfig`` and a ``RoBERTaWithAdversarialTraining``
       model with focal-loss weights.
    3. Optionally load weights from ``base_model_path/pytorch_model.bin``.
    4. Build the train/validation datasets and train with
       ``AdversarialTrainer``. The best checkpoint is selected by
       ``metric_for_best_model`` ("eval_accuracy").
    5. Plot the training curves into ``checkpoint_dir``.
    6. Save the final (best) model and tokenizer to ``output_dir``.
    7. Write JSON training records to ``checkpoint_dir`` and a summary to
       ``output_dir``.

    Args:
        train_data: training records passed to ``SentencePairDataset``.
        val_data: validation records passed to ``SentencePairDataset``.
        imbalance_ratio: class imbalance ratio passed to ``AdversarialConfig``.
        base_model_path: directory with previously fine-tuned weights
            (``pytorch_model.bin``). If loading fails, the default
            initialization is kept.
        model_path: pretrained backbone / tokenizer directory.
        output_dir: destination of the best model, tokenizer and summary JSON.
        checkpoint_dir: Trainer output dir; also receives plots and JSON records.

    Returns:
        ``(trainer, trainer.model, tokenizer, loss_tracker, adv_validation_callback)``.

    Raises:
        RuntimeError: re-raised from training. Out-of-memory errors are logged first.
    """
    _, gpu_memory, gpu_type = check_gpu_availability()
    device = torch.device('cuda')
    logger.info(f"🚀 使用GPU设备: {device} ({gpu_type})")

    # 1-based epoch ranges of the three phases. They mirror the 0-based
    # boundaries the trainer uses to pick a phase (epoch < 10 -> adaptation,
    # epoch < 25 -> enhancement, otherwise refinement).
    phase_epoch_ranges = {'adaptation': '1-10', 'enhancement': '11-25', 'refinement': '26+'}
    # Passed to AdversarialValidationCallback as ``epochs_interval``; also used
    # to list the per-interval analysis images in the info JSON.
    analysis_interval = 5

    # Adversarial-training configuration
    adv_config = AdversarialConfig(imbalance_ratio=imbalance_ratio, gpu_type=gpu_type)

    logger.info(f"📥 基于预训练模型进行对抗训练: {base_model_path}")
    tokenizer = BertTokenizer.from_pretrained(model_path)

    # Focal-loss class weights, moved to the training device.
    alpha_tensor = torch.tensor(adv_config.focal_alpha, dtype=torch.float).to(device)

    model = RoBERTaWithAdversarialTraining(
        model_path=model_path,
        num_labels=2,
        dropout=0.1,
        focal_alpha=alpha_tensor,
        focal_gamma=adv_config.focal_gamma,
        adversarial_config=adv_config
    )

    # Warm-start from previously fine-tuned weights when available.
    if os.path.exists(base_model_path):
        logger.info(f"🔄 加载预训练模型权重: {base_model_path}")
        try:
            checkpoint = torch.load(os.path.join(base_model_path, 'pytorch_model.bin'), map_location='cpu')
            load_result = model.load_state_dict(checkpoint, strict=False)
        except Exception as e:
            logger.warning(f"⚠️ 无法加载预训练权重,使用默认初始化: {str(e)}")
        else:
            # strict=False silently skips mismatched keys. Surface them so that a
            # key-prefix / architecture mismatch is not reported as a clean load.
            missing_keys = list(getattr(load_result, 'missing_keys', []))
            unexpected_keys = list(getattr(load_result, 'unexpected_keys', []))
            if missing_keys:
                logger.warning(f"⚠️ 预训练权重缺失 {len(missing_keys)} 个参数键 (示例: {missing_keys[:5]})")
            if unexpected_keys:
                logger.warning(f"⚠️ 预训练权重含 {len(unexpected_keys)} 个未使用的参数键 (示例: {unexpected_keys[:5]})")
            logger.info("✅ 预训练模型权重加载成功")

    model = model.to(device)
    logger.info(f"✅ 对抗训练模型已加载到GPU: {next(model.parameters()).device}")

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"📊 对抗训练模型参数统计:")
    logger.info(f" 🔹 总参数量: {total_params:,}")
    logger.info(f" 🔹 可训练参数: {trainable_params:,}")

    logger.info("📋 准备对抗训练数据集...")
    train_dataset = SentencePairDataset(train_data, tokenizer, max_length=512, is_validation=False, adversarial_config=adv_config)
    val_dataset = SentencePairDataset(val_data, tokenizer, max_length=512, is_validation=True)
    weighted_sampler = train_dataset.get_weighted_sampler()

    logger.info(f" 🔹 训练集大小: {len(train_dataset)}")
    logger.info(f" 🔹 验证集大小: {len(val_dataset)}")

    # Batch configuration. Adversarial passes need extra memory, so batches are
    # smaller and gradient accumulation keeps the effective batch size up.
    if gpu_type == "v100_48gb":
        batch_size = 12
        gradient_accumulation = 3  # effective batch size 36
        max_grad_norm = 1.0
        fp16 = True
        dataloader_num_workers = 4
    else:
        batch_size = 8
        gradient_accumulation = 4
        max_grad_norm = 0.5
        fp16 = True
        dataloader_num_workers = 2

    effective_batch_size = batch_size * gradient_accumulation

    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Evaluation and checkpointing both happen once per epoch:
    # load_best_model_at_end requires matching eval/save strategies. The
    # step-based eval_steps/save_steps settings have no effect under the
    # "epoch" strategy, so they are not set. save_total_limit bounds disk usage.
    training_args = TrainingArguments(
        output_dir=checkpoint_dir,
        num_train_epochs=30,  # gentle adversarial training, 3 phases
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        gradient_accumulation_steps=gradient_accumulation,
        eval_strategy="epoch",
        save_strategy="epoch",
        logging_strategy="steps",
        logging_steps=100,
        warmup_ratio=0.15,
        weight_decay=0.01,
        learning_rate=2e-5,
        load_best_model_at_end=True,
        remove_unused_columns=False,
        dataloader_pin_memory=True,
        fp16=fp16,
        dataloader_num_workers=dataloader_num_workers,
        group_by_length=True,
        report_to=[],
        adam_epsilon=1e-8,
        max_grad_norm=max_grad_norm,
        save_total_limit=1,
        skip_memory_metrics=True,
        disable_tqdm=False,
        lr_scheduler_type="cosine",
        warmup_steps=0,  # 0 => warmup_ratio is used
        metric_for_best_model="eval_accuracy",
        greater_is_better=True,
    )

    logger.info(f"🎯 对抗训练参数配置:")
    logger.info(f" 🔹 训练轮数: {training_args.num_train_epochs} (分3个阶段)")
    logger.info(f" 🔹 批次大小: {batch_size}")
    logger.info(f" 🔹 梯度累积: {gradient_accumulation}")
    logger.info(f" 🔹 有效批次大小: {effective_batch_size}")
    logger.info(f" 🔹 起始学习率: {training_args.learning_rate}")
    # Report the metric actually used for best-model selection.
    logger.info(f" 🔹 主要评估指标: {training_args.metric_for_best_model}")
    logger.info(f" 🔹 GPU优化配置: {gpu_type}")

    logger.info(f"🎯 对抗训练阶段设置:")
    for phase, config in adv_config.phase_configs.items():
        epochs = phase_epoch_ranges.get(phase, "?")
        logger.info(f" 🔹 {phase.capitalize()} ({epochs}): eps={config['eps']}, lr={config['learning_rate']}, clean_ratio={config['clean_ratio']}")

    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    # Callbacks: metric history + periodic adversarial analysis
    loss_tracker = LossTracker()
    adv_validation_callback = AdversarialValidationCallback(
        eval_dataset=val_dataset,
        tokenizer=tokenizer,
        output_dir=checkpoint_dir,
        adv_config=adv_config,
        epochs_interval=analysis_interval
    )

    trainer = AdversarialTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_adversarial_metrics,
        callbacks=[loss_tracker, adv_validation_callback],
        adversarial_config=adv_config,
        weighted_sampler=weighted_sampler
    )

    logger.info("🏃♂️ 开始完整的FreeLB对抗训练...")
    logger.info("🎯 对抗训练特性:")
    logger.info(" ✅ FreeLB多步对抗扰动生成")
    logger.info(" ✅ 类别差异化扰动强度 (少数类2.5倍)")
    logger.info(" ✅ 混合训练策略 (干净+对抗样本)")
    logger.info(" ✅ 三阶段渐进式训练")
    logger.info(f" ✅ 强化版Focal Loss (gamma={adv_config.focal_gamma})")
    logger.info(" ✅ 动态学习率调整")
    logger.info(" ✅ 少数类F1优化目标")
    logger.info(" ✅ V100 48GB内存优化")
    logger.info(" ✅ 实时对抗效果监控")

    start_time = datetime.now()

    try:
        trainer.train()
        logger.info(f"🏆 对抗训练完成!")
        logger.info(f" 🔹 最佳少数类F1: {loss_tracker.best_minority_f1:.4f}")
        logger.info(f" 🔹 最佳整体准确率: {loss_tracker.best_eval_accuracy:.4f}")
        logger.info(f" 🔹 最佳模型来自: Epoch {loss_tracker.best_epoch}")
    except RuntimeError as e:
        if "out of memory" in str(e).lower():
            logger.error("❌ GPU内存不足!")
            logger.error("💡 建议减小批次大小或对抗步数")
        raise

    end_time = datetime.now()
    training_duration = end_time - start_time
    logger.info(f"🎉 对抗训练完成! 耗时: {training_duration}")

    logger.info("📈 生成对抗训练可视化图表...")
    plot_adversarial_training_curves(loss_tracker, checkpoint_dir)

    logger.info(f"💾 保存最佳对抗训练模型到: {output_dir}")

    # With load_best_model_at_end=True, trainer.model holds the best checkpoint here.
    trainer.model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)

    # Raw metric history
    adversarial_history = {
        'train_losses': loss_tracker.train_losses,
        'train_steps': loss_tracker.train_steps,
        'eval_losses': loss_tracker.eval_losses,
        'eval_steps': loss_tracker.eval_steps,
        'eval_accuracies': loss_tracker.eval_accuracies,
        'eval_minority_f1': loss_tracker.eval_minority_f1,
        'epochs': loss_tracker.epochs,
        'best_eval_accuracy': loss_tracker.best_eval_accuracy,
        'best_minority_f1': loss_tracker.best_minority_f1,
        'best_epoch': loss_tracker.best_epoch,
        'adversarial_config': {
            'imbalance_ratio': adv_config.imbalance_ratio,
            'focal_alpha': adv_config.focal_alpha,
            'focal_gamma': adv_config.focal_gamma,
            'phase_configs': adv_config.phase_configs,
            'sampling_configs': adv_config.sampling_configs
        }
    }

    with open(os.path.join(checkpoint_dir, 'adversarial_training_history.json'), 'w', encoding='utf-8') as f:
        json.dump(adversarial_history, f, ensure_ascii=False, indent=2)

    # Phase records are taken from the live config, so the JSON cannot drift
    # from what was actually trained.
    phase_descriptions = {
        'adaptation': ("温和对抗训练适应期", "低"),
        'enhancement': ("强化对抗训练期", "中高"),
        'refinement': ("精调优化期", "中"),
    }
    training_phases = {
        f"{phase}_phase": {
            "epochs": phase_epoch_ranges[phase],
            "description": description,
            "clean_ratio": adv_config.phase_configs[phase]['clean_ratio'],
            "learning_rate": adv_config.phase_configs[phase]['learning_rate'],
            "perturbation_strength": strength
        }
        for phase, (description, strength) in phase_descriptions.items()
    }

    num_epochs = int(training_args.num_train_epochs)

    adversarial_info = {
        "model_name": model_path,
        "model_type": "Chinese-RoBERTa-WWM-Ext with FreeLB Adversarial Training",
        "base_model_path": base_model_path,
        "optimization_level": "freelb_adversarial_training",
        "training_duration": str(training_duration),
        "num_train_samples": len(train_dataset),
        "num_val_samples": len(val_dataset),
        "data_imbalance_ratio": imbalance_ratio,

        "adversarial_training_config": {
            "method": "FreeLB (Free Large-Batch)",
            "adversarial_steps": adv_config.adv_steps,
            "perturbation_type": "embedding_level",
            "norm_type": adv_config.norm_type,
            "max_norm": adv_config.max_norm,
            "adversarial_lr": adv_config.adv_lr,
            "class_differential_perturbation": True,
            # NOTE(review): these eps ranges are hard-coded and are not read from
            # adv_config; confirm they match the perturbation actually applied.
            "majority_class_eps_range": "0.08-0.12",
            "minority_class_eps_range": "0.2-0.3",
            "eps_ratio": adv_config.eps_ratio
        },

        "training_phases": training_phases,

        "best_model_info": {
            "best_minority_f1": float(loss_tracker.best_minority_f1),
            "best_overall_accuracy": float(loss_tracker.best_eval_accuracy),
            "best_epoch": int(loss_tracker.best_epoch),
            "model_selection_criterion": training_args.metric_for_best_model,
            "load_best_model_at_end": training_args.load_best_model_at_end
        },

        "focal_loss_enhancement": {
            "alpha_weights": adv_config.focal_alpha,
            "gamma": adv_config.focal_gamma,
            "optimization": "adversarial_training_specific",
            "minority_class_focus": "aggressive_9x_weight"
        },

        "performance_targets": {
            "primary_target": "minority_class_f1_improvement",
            "secondary_target": "overall_accuracy_maintenance",
            "expected_minority_f1_gain": "15-30%",
            "expected_robustness_improvement": "significant"
        },

        "gpu_optimization": {
            "gpu_name": torch.cuda.get_device_name(0),
            "gpu_memory_gb": gpu_memory,
            "optimization_target": gpu_type,
            "effective_batch_size": effective_batch_size,
            "memory_optimization": "adversarial_training_specific",
            "batch_size_reduction": f"{batch_size} (from 16 for memory)"
        },

        "training_args": {
            "num_train_epochs": training_args.num_train_epochs,
            "per_device_train_batch_size": training_args.per_device_train_batch_size,
            "gradient_accumulation_steps": training_args.gradient_accumulation_steps,
            "learning_rate": training_args.learning_rate,
            "warmup_ratio": training_args.warmup_ratio,
            "weight_decay": training_args.weight_decay,
            "fp16": training_args.fp16,
            "metric_for_best_model": training_args.metric_for_best_model,
            "load_best_model_at_end": training_args.load_best_model_at_end
        },

        "adversarial_innovations": [
            "FreeLB多步梯度扰动生成",
            "类别差异化扰动强度 (2.5倍比例)",
            "三阶段渐进式对抗训练",
            "混合训练策略 (干净+对抗样本)",
            "少数类F1优化目标",
            "动态学习率阶段调整",
            f"强化版Focal Loss (gamma={adv_config.focal_gamma})",
            "V100大显存充分利用",
            "实时对抗效果监控",
            "类别平衡采样策略"
        ],

        "expected_improvements": [
            "少数类召回率提升 20-35%",
            "少数类F1分数提升 15-30%",
            "模型鲁棒性显著增强",
            "对输入扰动的稳定性提升",
            "泛化能力改善",
            "整体准确率维持或轻微提升"
        ],

        "paths": {
            "base_model_path": base_model_path,
            "model_output_path": output_dir,
            "checkpoint_output_path": checkpoint_dir,
            "training_logs": checkpoint_dir
        },

        "visualization_files": {
            "comprehensive_curves": "comprehensive_adversarial_training_curves.png",
            "key_metrics": "key_adversarial_metrics.png",
            # One analysis image per callback interval within the configured epochs.
            "adversarial_analysis": [
                f"adversarial_analysis_epoch_{i}.png"
                for i in range(analysis_interval, num_epochs + 1, analysis_interval)
            ],
            "training_history": "adversarial_training_history.json"
        },

        "training_completed": end_time.isoformat()
    }

    with open(os.path.join(checkpoint_dir, 'adversarial_training_info.json'), 'w', encoding='utf-8') as f:
        json.dump(adversarial_info, f, ensure_ascii=False, indent=2)

    # Short summary stored next to the saved model
    adversarial_summary = {
        "adversarial_model_info": {
            "best_minority_f1": float(loss_tracker.best_minority_f1),
            "best_overall_accuracy": float(loss_tracker.best_eval_accuracy),
            "best_epoch": int(loss_tracker.best_epoch),
            "training_method": "FreeLB_Adversarial_Training",
            "selection_criterion": training_args.metric_for_best_model
        },
        "adversarial_config": adversarial_info
    }

    with open(os.path.join(output_dir, 'adversarial_model_info.json'), 'w', encoding='utf-8') as f:
        json.dump(adversarial_summary, f, ensure_ascii=False, indent=2)

    logger.info("📋 对抗训练信息和模型记录已保存")

    return trainer, trainer.model, tokenizer, loss_tracker, adv_validation_callback
|
|
|
|
|
|
|
|
def main():
    """Entry point: run the full FreeLB adversarial-training pipeline.

    Steps:

    1. Check free disk space under /root/autodl-tmp and abort below 5 GB.
       If the check itself fails, log a warning and continue.
    2. Create the output directories.
    3. Load the training file and split it 80/20 into train/validation.
    4. Run ``train_adversarial_roberta_model`` and log a summary of the best
       model, including relative accuracy against a 0.9451 baseline.

    Exceptions raised during training are logged with their traceback and
    re-raised.
    """
    import shutil

    logger.info("=" * 120)
    logger.info("🚀 Chinese-RoBERTa-WWM-Ext FreeLB对抗训练 (存储优化版)")
    logger.info("=" * 120)

    # Paths
    train_file = "/root/autodl-tmp/chinese-roberta-wwm-ext/chinese-roberta-wwm-ext/Data/train_dataset.json"
    model_path = "/root/autodl-tmp/chinese-roberta-wwm-ext/chinese-roberta-wwm-ext/model"
    base_model_path = "/root/autodl-tmp/chinese-roberta-wwm-ext/chinese-roberta-wwm-ext/model_train"  # previously fine-tuned model
    output_dir = "/root/autodl-tmp/chinese-roberta-wwm-ext/chinese-roberta-wwm-ext/model_adversarial"
    checkpoint_dir = "/root/autodl-tmp/chinese-roberta-wwm-ext/chinese-roberta-wwm-ext/output_adversarial"

    # Disk-space check. A missing/unreadable mount should not crash the run
    # before the output directories are even created.
    try:
        free = shutil.disk_usage("/root/autodl-tmp").free
    except OSError as e:
        logger.warning(f"⚠️ 无法检查磁盘空间: {e}")
    else:
        free_gb = free // (1024**3)
        logger.info(f"📁 磁盘空间检查:")
        logger.info(f" 🔹 可用空间: {free_gb} GB")

        if free_gb < 5:
            logger.warning("⚠️ 磁盘空间不足,建议清理后再运行训练")
            logger.info("💡 建议执行以下命令清理空间:")
            logger.info(" rm -rf /root/.cache/*")
            logger.info(" rm -rf /tmp/*")
            logger.info(" find /root/autodl-tmp -name 'checkpoint-*' -type d | head -10 | xargs rm -rf")
            return

    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)
    logger.info(f"📁 确保输出目录存在:")
    logger.info(f" 🔹 对抗训练模型输出: {output_dir}")
    logger.info(f" 🔹 对抗训练记录: {checkpoint_dir}")

    logger.info(f"\n📋 FreeLB对抗训练配置 (存储优化版):")
    logger.info(f" 🔹 训练数据: {train_file}")
    logger.info(f" 🔹 基础模型: {model_path}")
    logger.info(f" 🔹 预训练模型: {base_model_path}")
    logger.info(f" 🔹 对抗训练方法: FreeLB (温和版本)")
    # Checkpoints are written every epoch (the save strategy must match the
    # per-epoch evaluation for load_best_model_at_end); save_total_limit bounds disk use.
    logger.info(f" 🔹 存储优化: 每轮保存checkpoint,save_total_limit=1")
    logger.info(f" 🔹 硬件: V100 48GB GPU优化")

    logger.info(f"\n🎯 温和对抗训练特性:")
    logger.info(f" 🔹 温和扰动强度 (多数类0.05, 少数类0.1)")
    logger.info(f" 🔹 减少对抗步数 (3步)")
    logger.info(f" 🔹 延长适应期 (前10轮)")
    logger.info(f" 🔹 更多干净样本混合")
    logger.info(f" 🔹 存储空间最小化")

    # Load and split data
    train_data, val_data, imbalance_ratio = load_and_split_data(train_file, validation_split=0.2, random_state=42)
    if train_data is None or val_data is None:
        logger.error("❌ 无法加载和划分数据,程序退出")
        return

    logger.info(f"📊 数据不平衡分析: {imbalance_ratio:.1f}:1")

    try:
        trainer, best_model, tokenizer, loss_tracker, adv_callback = train_adversarial_roberta_model(
            train_data, val_data, imbalance_ratio,
            base_model_path=base_model_path,
            model_path=model_path,
            output_dir=output_dir,
            checkpoint_dir=checkpoint_dir
        )

        logger.info("=" * 120)
        logger.info("🎉 FreeLB温和对抗训练完成!")
        logger.info("=" * 120)
        logger.info(f"🏆 最佳对抗训练模型信息:")
        logger.info(f" 🔹 少数类F1分数: {loss_tracker.best_minority_f1:.4f}")
        logger.info(f" 🔹 整体准确率: {loss_tracker.best_eval_accuracy:.4f}")
        logger.info(f" 🔹 最佳模型来自: Epoch {loss_tracker.best_epoch}")
        logger.info(f" 🔹 选择标准: 整体准确率最高")

        # Relative improvement over the non-adversarial baseline accuracy
        baseline_acc = 0.9451
        if loss_tracker.best_eval_accuracy > 0:
            acc_improvement = ((loss_tracker.best_eval_accuracy - baseline_acc) / baseline_acc * 100)
            logger.info(f" 🔹 准确率改进: {acc_improvement:+.2f}%")

        logger.info(f"\n📁 文件输出位置:")
        logger.info(f" 🔹 最佳对抗训练模型: {output_dir}")
        logger.info(f" 🔹 训练记录: {checkpoint_dir}")

        logger.info("🔥 温和FreeLB对抗训练的核心优势:")
        logger.info(" ✅ 平稳的性能提升过程")
        logger.info(" ✅ 温和但有效的鲁棒性增强")
        logger.info(" ✅ 避免性能大幅下降")
        logger.info(" ✅ 存储空间最小化")
        logger.info(" ✅ 适合有限资源环境")

        logger.info(f"\n🎯 温和对抗训练完成,模型已优化并保存!")

    except Exception as e:
        # logger.exception records the traceback through the logging pipeline.
        logger.exception(f"❌ 对抗训练过程中出现错误: {str(e)}")
        raise
|
|
|
|
|
|
|
|
# Run the training pipeline only when this file is executed as a script,
# never as a side effect of importing it.
if __name__ == '__main__':
    main()