- Title: Vulnerability-Aware Alignment: Mitigating Uneven Forgetting in Harmful Fine-Tuning
- Authors: Liang Chen, Xueting Han, Li Shen, Jing Bai, Kam-Fai Wong. Their affiliations are not explicitly listed in the header, but the acknowledgments mention CUHK (The Chinese University of Hong Kong).
- Journal/Conference: The paper is hosted on OpenReview, a platform commonly used for submissions to top-tier venues such as ICLR, NeurIPS, and ICML; the formatting is typical of a conference paper.
- Publication Year: The paper cites work from 2024, indicating it dates from 2024 or later.
- Abstract: The paper addresses the problem of Harmful Fine-Tuning (HFT), where the safety alignment of Large Language Models (LLMs) is compromised during fine-tuning. The authors observe that existing methods treat all alignment data equally, ignoring that certain subsets are more "vulnerable" to being forgotten. They propose Vulnerability-Aware Alignment (VAA), a method that first identifies and partitions data into "vulnerable" and "invulnerable" groups. It then uses a Group Distributionally Robust Optimization (Group DRO) framework to ensure balanced learning across these groups. VAA employs an adversarial sampler to focus on the underperforming group and applies group-specific perturbations to enhance robustness. Experiments show VAA significantly reduces harmfulness while maintaining performance on downstream tasks, outperforming existing methods.
- Original Source Link: The paper is available at https://openreview.net/forum?id=EMHED4WTHT, with the PDF at https://openreview.net/pdf?id=EMHED4WTHT. Its status appears to be under review or recently accepted at a conference.
2. Executive Summary
- Foundational Concepts:
- Large Language Models (LLMs): These are deep learning models with billions of parameters, trained on massive text data to understand and generate human-like language (e.g., GPT-4, Llama).
- Fine-Tuning: The process of taking a pre-trained LLM and further training it on a smaller, task-specific dataset to adapt its behavior.
- Safety Alignment: A crucial step after pre-training where an LLM is trained to be helpful and harmless. This usually involves techniques like Supervised Fine-Tuning (SFT) on safety data and Reinforcement Learning from Human Feedback (RLHF) to avoid generating toxic, biased, or dangerous content.
- Harmful Fine-Tuning (HFT): A malicious or unintentional process where fine-tuning an aligned LLM erodes its safety alignment, making it susceptible to generating harmful responses. This can happen even if the fine-tuning dataset is small or contains only benign examples.
- Distributionally Robust Optimization (DRO): A machine learning optimization paradigm. Instead of minimizing the average loss over the training data (known as Empirical Risk Minimization or ERM), DRO aims to minimize the worst-case loss over a set of possible data distributions. This makes the model more robust to shifts or imbalances in the data.
- Group DRO: A specific type of DRO where the data is partitioned into known groups (e.g., based on demographics, or in this case, vulnerability). The goal is to minimize the loss of the worst-performing group, ensuring fairness and balanced performance across all groups.
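As a toy illustration of the ERM vs. Group DRO objectives (our sketch, not from the paper), compare a size-weighted average loss with a worst-group loss; `group_losses` and `group_sizes` are assumed to be per-group quantities computed elsewhere:

```python
import torch

def erm_loss(group_losses, group_sizes):
    """ERM: a size-weighted average loss, so the larger group dominates the gradient."""
    total = sum(group_sizes)
    return sum(loss * (n / total) for loss, n in zip(group_losses, group_sizes))

def group_dro_loss(group_losses):
    """Group DRO: optimize the loss of the worst-performing group."""
    return torch.stack(list(group_losses)).max()
```

With, say, a loss of 0.2 on a large invulnerable group and 1.5 on a small vulnerable group, ERM reports a value close to 0.2 and its gradient largely ignores the weak group, while Group DRO reports 1.5 and drives updates toward it.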
- Previous Works:
The paper categorizes HFT mitigation methods into three stages:
- Alignment-Stage Methods: These methods modify the initial safety alignment process to make the model inherently more robust before it is handed over for fine-tuning. This is the category VAA belongs to.
  - Vaccine (Huang et al., 2024d): Adds perturbations to the model's hidden embeddings during alignment to reduce "embedding drift" and prevent forgetting.
  - RepNoise (Rosati et al., 2024a): Makes harmful data "unlearnable" by manipulating the model's representations.
  - Booster (Huang et al., 2024b): Uses a regularizer to slow down the rate at which the model learns from harmful examples.
- Fine-Tuning-Stage Methods: These methods regulate the fine-tuning process itself, for example, by adding constraints or regularizers during user-led fine-tuning.
- Post-Fine-Tuning-Stage Methods: These methods attempt to "repair" a model after it has been compromised by HFT.
- Differentiation:
While previous alignment-stage methods like Vaccine and Booster focus on making the model's parameters or representations robust, they treat all alignment data points as equally important. VAA's key innovation is its data-centric approach. It is the first to systematically study the uneven vulnerability of alignment data. By identifying and explicitly modeling this vulnerability using a Group DRO framework, VAA directly tackles the root cause of why some safety knowledge is forgotten more easily than others, leading to a more targeted and effective defense.
4. Methodology (Core Technology & Implementation)
The paper's core idea is to first identify which alignment examples are easily "forgotten" during HFT and then use a robust training procedure to force the model to learn these "vulnerable" examples just as well as the "invulnerable" ones. This is done in two main stages.
Stage 1: Group Estimation by Vulnerability
The first step is to partition the alignment dataset into two groups: vulnerable and invulnerable.
Figure (schematic): the alignment dataset is partitioned into two subsets, G1 (invulnerable) and G2 (vulnerable), along with each group's share of the empirical distribution, illustrating the core idea of vulnerability-based grouping.
- Defining and Calculating Data Vulnerability: The paper defines vulnerability as the tendency of an alignment example to be "forgotten" during HFT. Forgetting is measured by tracking the Harmful Score (HS) of each example throughout a simulated fine-tuning process. The HS for an example is a binary value: 1 if the model's output for it is harmful, 0 otherwise. An example is considered "forgotten" at a time step if its HS increases from its initial value (e.g., from 0 to 1).
- The ForgotNum Metric: To quantify vulnerability, the authors introduce ForgotNum. For each alignment example i, they count the total number of times it is forgotten over T training steps of HFT:
  $$\mathrm{ForgotNum}_i = \sum_{t=1}^{T} \mathbb{I}\!\left(HS_i^t > HS_i^0\right)$$
  where $HS_i^t$ is the harmful score of example i at step t, $HS_i^0$ is the initial score, and $\mathbb{I}(\cdot)$ is the indicator function. A higher ForgotNum indicates greater vulnerability.
- Proxy Fine-Tuning for Grouping: Since the actual downstream fine-tuning data is unknown, the authors use a proxy dataset (the Alpaca dataset mixed with 10% harmful data) to simulate HFT. They fine-tune an aligned model on this proxy data and calculate the ForgotNum for every example in the original alignment set.
- Partitioning: Examples with ForgotNum > 0 are classified as vulnerable, and those with ForgotNum = 0 are classified as invulnerable. This grouping serves as a prior for the next stage.
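A minimal sketch of Stage 1, assuming a hypothetical `harmful_score(model, example)` judge that returns 0/1 and a list of checkpoints saved during the proxy fine-tuning run (an illustration, not the authors' implementation):

```python
from typing import Callable, List, Sequence

def compute_forgot_num(
    checkpoints: Sequence,       # models saved at steps 0..T of the proxy fine-tuning run
    alignment_set: Sequence,     # alignment examples to track
    harmful_score: Callable,     # hypothetical judge: 1 if the model's output is harmful, else 0
) -> List[int]:
    """ForgotNum_i = number of steps t with HS_i^t > HS_i^0."""
    initial = [harmful_score(checkpoints[0], ex) for ex in alignment_set]
    forgot_num = [0] * len(alignment_set)
    for model_t in checkpoints[1:]:
        for i, ex in enumerate(alignment_set):
            if harmful_score(model_t, ex) > initial[i]:
                forgot_num[i] += 1
    return forgot_num

def partition_by_vulnerability(alignment_set, forgot_num):
    """Examples forgotten at least once are 'vulnerable'; the rest are 'invulnerable'."""
    vulnerable = [ex for ex, n in zip(alignment_set, forgot_num) if n > 0]
    invulnerable = [ex for ex, n in zip(alignment_set, forgot_num) if n == 0]
    return vulnerable, invulnerable
```

The grouping is a one-time cost: it is computed once on the proxy run and reused as a fixed prior in Stage 2.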
Empirical Motivation: Data-Dependent and Transferable Forgetting
The paper provides strong empirical justification for this grouping strategy.
Figure 2: Forgetting behavior under different poison rates and fine-tuning tasks. Panel (a) shows, for SST2 at poison rates of 0%, 10%, and 20%, the proportions of four example categories: Forget, Common Forget, Unforget, and Common Unforget. Panel (b) compares forgetting patterns across SST2, GSM8K, and AGNews at a fixed 10% poison rate. The data show that some examples are consistently prone to forgetting.
As shown in Figure 2, the analysis reveals:
- Forgetting is Data-Dependent: In Figure 2(a), even as the proportion of harmful data (poison rate) increases, a significant portion of the forgotten examples (the shaded red "Common Forget" region) is the same across all settings. This shows that certain examples are inherently prone to being forgotten.
- Forgetting is Transferable: In Figure 2(b), when fine-tuning on completely different tasks (SST2, GSM8K, AGNews), there is still a substantial overlap in which alignment examples are forgotten. This transferability justifies using a single proxy dataset for grouping.
- Vulnerable Data is Less Robust: Figure 3 (3D loss-surface plots; left: vulnerable data, right: invulnerable data) compares the loss landscape for the two groups. The landscape for vulnerable data is much "sharper" and more chaotic, while for invulnerable data it is "flatter." A flatter loss landscape indicates greater robustness to changes in model weights, suggesting vulnerable examples are less robustly learned during standard alignment.
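Loss surfaces like these are commonly drawn by perturbing the trained weights along two random, weight-normalized directions and re-evaluating the group loss on a grid. The following is a hedged sketch of that general visualization recipe, not the paper's plotting code:

```python
import torch

@torch.no_grad()
def loss_surface(model, loss_fn, group_batch, radius=0.05, steps=11):
    """Probe the loss around the current weights along two random directions."""
    params = list(model.parameters())
    base = [p.detach().clone() for p in params]

    # Two random directions, rescaled to the norm of each weight tensor.
    directions = []
    for _ in range(2):
        d = [torch.randn_like(p) for p in base]
        d = [di * (pi.norm() / (di.norm() + 1e-12)) for di, pi in zip(d, base)]
        directions.append(d)

    alphas = torch.linspace(-radius, radius, steps)
    surface = torch.zeros(steps, steps)
    for i, a in enumerate(alphas):
        for j, b in enumerate(alphas):
            for p, p0, d1, d2 in zip(params, base, directions[0], directions[1]):
                p.copy_(p0 + a * d1 + b * d2)
            surface[i, j] = loss_fn(model, group_batch)
    for p, p0 in zip(params, base):  # restore the original weights
        p.copy_(p0)
    return surface  # a sharper surface indicates a less robustly learned group
```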
Stage 2: Vulnerability-Aware Alignment (VAA)
With the data partitioned, VAA uses a Group DRO framework to ensure balanced learning. This is framed as a two-player game between an adversarial sampler and the LLM.
Figure (schematic of the VAA method): a five-step training procedure. Step 1 obtains group weights for the groups Gi via adversarial sampling; Step 2 applies group-specific adversarial perturbations εi to each group's data; Step 3 computes the loss; Step 4 performs mirror-ascent optimization of the sampler; Step 5 updates the LLM parameters via gradient descent, reflecting balanced learning under the group-robust optimization framework.
- Principles: Standard training, or Empirical Risk Minimization (ERM), optimizes the average loss across all data:
  $$\theta_{\mathrm{ERM}} := \arg\min_{\theta \in \Theta} \mathbb{E}_{(x,y)\sim \hat{P}}\left[\ell(\theta;(x,y))\right]$$
  This is problematic because the invulnerable group is typically much larger, so its gradients dominate, leading to under-learning of the vulnerable group (a phenomenon called "gradient starvation"). Group DRO addresses this by minimizing the worst-case loss across the groups:
  $$\hat{\theta}_{\mathrm{DRO}} = \arg\min_{\theta \in \Theta}\, \left\{ \sup_{G_i \in \mathcal{Q}} \mathbb{E}_{(x,y)\sim G_i}\left[f_i(\theta;(x,y))\right] \right\}$$
  where $G_i$ is a data group (vulnerable or invulnerable) and $f_i(\theta)$ is a robust objective for that group. This forces the model to perform well on all groups, especially the one it currently finds most difficult.
- Steps & Procedures (The Two-Player Game):
- Adversarial Sampler (Player 1): At each training step, a sampler chooses a group to train on. It maintains a probability distribution q over the groups. It adversarially updates q to sample more frequently from the group with the highest current loss (the "hard" group).
- LLM (Player 2): The LLM receives a batch of data from the group selected by the sampler. It then updates its parameters θ to minimize its loss on this challenging data.
- Mathematical Formulas & Key Details:
1. The Robust Objective $f_i(\theta)$: To proactively handle the parameter shifts caused by HFT, VAA uses a robust objective that combines the standard loss with a perturbed loss:
   $$f_i(\theta) = (1-\lambda)\,\ell_i(\theta) + \lambda\,\ell_i(\theta + \epsilon_i)$$
   - $\ell_i(\theta)$: the standard loss (e.g., cross-entropy) for group $G_i$.
   - $\epsilon_i$: a group-dependent adversarial perturbation applied to the model's weights $\theta$, computed to maximize the loss for that group.
   - $\lambda$: a hyperparameter balancing standard and robust learning; it is gradually increased from 0 to 1 during training (curriculum learning).
2. Learning the Adversarial Sampler q: The sampler's distribution q is updated via mirror ascent, a technique for optimization over a probability simplex. This yields an update rule similar to the EXP3 algorithm from adversarial bandits:
   $$q_i^{(t)} = \frac{1}{Z}\, q_i^{(t-1)} \exp\!\left(\eta_q\, f_i(\theta^{(t-1)})\right)$$
   - $q_i^{(t)}$: the probability of sampling group i at step t.
   - $f_i(\theta^{(t-1)})$: the loss of group i from the previous step, which acts as the "reward" signal.
   - $\eta_q$: the step size for the sampler.
   - $Z$: a normalization constant ensuring the probabilities sum to 1.
   This rule increases the sampling probability of whichever group currently has the higher loss, forcing the LLM to focus on it.
3. Training the LLM θ: Once a group $G_i$ is sampled, the LLM performs a standard gradient descent step on the robust objective $f_i(\theta)$ for a batch of data from that group.
This iterative process continues until convergence, at which point the LLM has learned to perform equally well on both vulnerable and invulnerable data, making its safety alignment more robust.
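The following PyTorch-style sketch reconstructs this loop under stated assumptions: `get_batch(g)` is a hypothetical per-group data loader, `loss_fn(model, batch)` returns the scalar alignment loss, and the group perturbation εi is approximated with a single sign-gradient ascent step (SAM-style). The paper's exact perturbation computation and hyperparameters may differ.

```python
import math
import torch

def vaa_train(model, get_batch, loss_fn, total_steps,
              lr=2e-5, eta_q=0.1, rho=1e-2, num_groups=2):
    """Simplified VAA loop: adversarial group sampler (Player 1) vs. LLM (Player 2)."""
    q = torch.full((num_groups,), 1.0 / num_groups)   # sampling distribution over groups
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    params = [p for p in model.parameters() if p.requires_grad]

    for step in range(total_steps):
        lam = step / max(total_steps - 1, 1)           # curriculum: lambda ramps 0 -> 1
        g = int(torch.multinomial(q, 1))               # Player 1: pick a group to train on
        batch = get_batch(g)
        opt.zero_grad()

        # (1 - lambda) * standard loss, gradients taken at the current weights.
        loss_clean = loss_fn(model, batch)
        ((1.0 - lam) * loss_clean).backward()

        # Group perturbation eps_g: one sign-gradient ascent step that increases the group loss.
        ascent_grads = torch.autograd.grad(loss_fn(model, batch), params)
        eps = [rho * gr.sign() for gr in ascent_grads]
        with torch.no_grad():
            for p, e in zip(params, eps):
                p.add_(e)

        # lambda * perturbed loss, gradients taken at the perturbed weights (SAM-style).
        loss_pert = loss_fn(model, batch)
        (lam * loss_pert).backward()
        with torch.no_grad():                          # remove the perturbation before stepping
            for p, e in zip(params, eps):
                p.sub_(e)
        opt.step()                                     # Player 2: descend on f_g

        # Mirror-ascent / EXP3-style sampler update: up-weight the harder group.
        f_g = (1.0 - lam) * loss_clean.item() + lam * loss_pert.item()
        q[g] = q[g] * math.exp(eta_q * f_g)
        q = q / q.sum()                                # renormalize onto the simplex
    return model
```

The sketch favors clarity over efficiency (it recomputes forward passes for the ascent step); the paper reports an overall cost of roughly 1.5x the backpropagation of standard SFT.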
5. Experimental Setup
6. Results & Analysis
The paper's experiments demonstrate VAA's effectiveness across various settings. The following tables are transcribed from the paper.
Core Results
Generalization to Fine-tuning Datasets
This experiment evaluates VAA on four different tasks with a 10% poison ratio.
Table 1: Performance analysis for different fine-tuning tasks. Best in bold, second-best underlined.
This is a manual transcription of Table 1 from the paper.
| Methods | SST2 HS ↓ | SST2 FA ↑ | AGNEWS HS ↓ | AGNEWS FA ↑ | GSM8K HS ↓ | GSM8K FA ↑ | AlpacaEval HS ↓ | AlpacaEval FA ↑ | Average HS ↓ | Average FA ↑ |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| SFT | 32.87 | 91.00 | 33.07 | 87.40 | 41.63 | 6.80 | 30.48 | 39.73 | 34.51 | 56.23 |
| RepNoise | 27.89 | 90.40 | 27.29 | 84.00 | 41.83 | 6.60 | 34.66 | 36.21 | 32.92 | 54.30 |
| Vaccine | 27.69 | 89.40 | 30.28 | 85.60 | 34.66 | 6.20 | 32.47 | 38.62 | 31.28 | 54.96 |
| Booster | 25.90 | 91.80 | 31.87 | 87.00 | 41.04 | 6.40 | 40.24 | 39.41 | 34.76 | 56.15 |
| VAA | 20.00 | 91.00 | 21.12 | 87.40 | 31.08 | 8.60 | 27.09 | 40.06 | 24.82 | 56.77 |
- Analysis: VAA achieves the lowest (best) Harmful Score (HS) across all four diverse fine-tuning tasks, often by a large margin. On average, its HS is 24.82, significantly lower than the next best baseline (Vaccine at 31.28). Importantly, this safety improvement does not come at the cost of utility; VAA maintains or improves the Fine-tuning Accuracy (FA), achieving the highest average FA. Baselines like RepNoise and Booster struggle on complex tasks like GSM8K and AlpacaEval, whereas VAA remains effective, showing its robustness.
Robustness to Harmful Ratio
This experiment varies the percentage of harmful data (p) in the fine-tuning set for the SST2 task.
Table 2: Performance analysis for different harmful ratios.
This is a manual transcription of Table 2 from the paper.
| Methods | HS ↓ (p=0%) | HS ↓ (p=10%) | HS ↓ (p=20%) | HS ↓ (Avg.) | FA ↑ (p=0%) | FA ↑ (p=10%) | FA ↑ (p=20%) | FA ↑ (Avg.) |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| SFT | 23.11 | 32.87 | 38.84 | 31.61 | 91.80 | 91.00 | 90.00 | 90.93 |
| RepNoise | 22.91 | 27.89 | 35.26 | 28.69 | 90.20 | 90.40 | 90.60 | 90.40 |
| Vaccine | 21.31 | 27.69 | 36.65 | 28.55 | 90.40 | 89.40 | 90.00 | 89.93 |
| Booster | 14.54 | 25.90 | 30.28 | 23.57 | 90.20 | 91.80 | 90.40 | 90.80 |
| VAA | 12.35 | 20.00 | 25.30 | 19.22 | 90.60 | 91.00 | 91.20 | 90.93 |
- Analysis: VAA consistently provides the best protection (lowest HS) regardless of the poison ratio. A key finding is its strong performance at p=0%. This shows that VAA not only defends against malicious HFT but also mitigates the safety degradation caused by fine-tuning on purely benign data, a critical and often overlooked problem.
Robustness to Harmful Fine-tuning Epochs
This experiment shows how safety degrades as HFT continues for more epochs on the SST2 task.
Table 3: Performance analysis for different harmful fine-tuning epochs.
This is a manual transcription of Table 3 from the paper.
| Methods | HS ↓ (epoch=1) | HS ↓ (epoch=3) | HS ↓ (epoch=5) | HS ↓ (Avg.) | FA ↑ (epoch=1) | FA ↑ (epoch=3) | FA ↑ (epoch=5) | FA ↑ (Avg.) |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| SFT | 27.69 | 31.67 | 32.87 | 30.74 | 90.00 | 91.00 | 91.00 | 90.67 |
| RepNoise | 27.89 | 30.88 | 27.89 | 28.89 | 90.20 | 91.20 | 90.40 | 90.60 |
| Vaccine | 25.30 | 29.08 | 27.69 | 27.36 | 84.00 | 88.80 | 89.40 | 87.40 |
| Booster | 29.08 | 24.10 | 25.90 | 26.36 | 89.20 | 88.80 | 91.80 | 89.93 |
| VAA | 14.60 | 19.20 | 20.00 | 17.93 | 90.00 | 91.40 | 91.00 | 90.80 |
- Analysis: As expected, longer fine-tuning leads to higher harmfulness for all methods. However, VAA maintains a significantly lower HS at every checkpoint, demonstrating sustained robustness over the course of HFT.
Generalization to Different LLMs
This experiment tests whether VAA's benefits carry over to a different model, Qwen2.5-7B. Crucially, the vulnerable/invulnerable group assignments were derived from LLaMA2 and transferred directly without re-computation.
Table 4: Performance analysis for different harmful fine-tuning epochs on Qwen2.5-7B.
This is a manual transcription of Table 4 from the paper.
| Methods | HS ↓ Ep1 | HS ↓ Ep2 | HS ↓ Ep3 | HS ↓ Ep4 | HS ↓ Ep5 | FA ↑ Ep1 | FA ↑ Ep2 | FA ↑ Ep3 | FA ↑ Ep4 | FA ↑ Ep5 |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| SFT | 26.89 | 31.47 | 31.08 | 33.27 | 33.07 | 84.80 | 76.80 | 86.80 | 87.20 | 86.40 |
| RepNoise | 22.31 | 25.10 | 26.49 | 30.68 | 30.88 | 83.00 | 81.60 | 88.80 | 87.20 | 88.00 |
| Vaccine | 29.88 | 29.28 | 28.88 | 30.48 | 29.48 | 82.60 | 84.20 | 83.20 | 85.40 | 85.60 |
| Booster | 19.92 | 21.91 | 25.10 | 26.29 | 30.28 | 85.20 | 84.80 | 87.40 | 87.60 | 88.00 |
| VAA | 17.73 | 18.33 | 20.12 | 21.91 | 22.11 | 86.20 | 86.40 | 85.40 | 87.60 | 88.60 |
- Analysis: VAA again achieves the best safety performance on a completely different model architecture. This result strongly supports the hypothesis that data vulnerability is a fundamental property that transfers across models, making the VAA approach generalizable.
Ablations / Discussion
Impact of Example Grouping and Sampling
These experiments validate VAA's core design choices on the SST2 task.
Table 5: Impact of example grouping.
This is a manual transcription of Table 5 from the paper.
| Strategy | HS ↓ | FA ↑ |
| --- | --- | --- |
| VAA | 20.00 | 91.00 |
| w/o group | 26.42 | 90.08 |
| w/ noisy group | 21.08 | 91.20 |
- Analysis: Removing the grouping (w/o group) significantly worsens the HS, proving that explicitly modeling vulnerability is crucial. When 10% of the group labels are randomized (w/ noisy group), performance degrades only slightly, showing that the Group DRO framework is robust to imperfect group assignments.
Table 6: Impact of sampling strategies.
This is a manual transcription of Table 6 from the paper.
| Strategy | HS ↓ | FA ↑ |
| --- | --- | --- |
| VAA | 20.00 | 91.00 |
| Vuln. group only | 29.26 | 90.15 |
| Invuln. group only | 33.98 | 91.20 |
| Imp. sampling | 28.64 | 90.35 |
- Analysis: VAA's dynamic, adversarial sampling is superior. Sampling only from the vulnerable group is better than sampling only from the invulnerable group (confirming that vulnerable data is more informative for robustness), but both are suboptimal. Standard importance sampling (Imp. sampling), a common technique for imbalanced data, is also less effective than VAA's adaptive approach.
Computational Overhead
VAA is computationally efficient: it requires 1.5x the backpropagation steps of standard SFT, which is cheaper than Vaccine (2x) and Booster (3x), while delivering better performance.
7. Conclusion & Reflections
- Conclusion Summary: The paper identifies and addresses the problem of uneven forgetting in harmful fine-tuning. It demonstrates that certain alignment examples are consistently vulnerable and that this vulnerability is transferable. The proposed method, VAA, leverages this insight by using a Group DRO framework to promote balanced learning between vulnerable and invulnerable data. Experimental results confirm that VAA significantly enhances safety against HFT across various tasks and models without compromising utility, outperforming existing alignment-stage defenses.
- Limitations & Future Work: The authors acknowledge several limitations:
- The data partitioning method is simple and requires a proxy fine-tuning step. Future work could explore more sophisticated, continuous measures of vulnerability that don't need this extra stage.
- VAA reduces but does not completely eliminate safety breakdowns. It could be combined with other techniques like AI watermarking for more comprehensive protection.
- Personal Insights & Critique:
- Novelty and Impact: The paper's primary strength is its shift to a data-centric view of alignment robustness. Identifying "vulnerable" data points is a powerful and intuitive concept that moves beyond treating alignment data as a monolithic block. This perspective could inspire a new class of defense mechanisms.
- Practicality: The one-time cost of grouping data and the modest overhead of VAA training make it a practical solution for developers of open-source models or providers of fine-tuning services. The demonstrated transferability of vulnerability across models is a particularly strong result, as it means the expensive grouping step doesn't need to be repeated for every new model architecture.
- Potential Weaknesses: The definition of "vulnerable" (ForgotNum > 0) is binary and depends on a single proxy fine-tuning run. The choice of proxy dataset could influence which examples are flagged. A more nuanced, probabilistic, or multi-proxy approach might yield even more robust groupings.
- Future Directions: This work opens up interesting questions. What linguistic or semantic properties characterize vulnerable examples? Are they more complex, ambiguous, or related to topics where safety boundaries are subtle? Answering these could lead to even better methods for generating or identifying robust alignment data from the start. Furthermore, applying this "vulnerable group" concept to other domains, such as mitigating bias or improving out-of-distribution generalization, seems like a promising avenue for future research.