Mitigating Forgetting in LLM Supervised Fine-Tuning and Preference Learning
TL;DR Summary
This work theoretically proves the suboptimality of sequential SFT and preference learning in LLMs, and proposes a joint training framework with convergence guarantees that achieves a better trade-off between the two objectives at a computational cost similar to sequential training.
Abstract
Post-training of pre-trained LLMs, which typically consists of the supervised fine-tuning (SFT) stage and the preference learning (RLHF or DPO) stage, is crucial to effective and safe LLM applications. The widely adopted approach for post-training popular open-source LLMs is to sequentially perform SFT and RLHF/DPO. However, sequential training is sub-optimal in terms of the SFT and RLHF/DPO trade-off: the LLM gradually forgets the first stage's training when undergoing the second stage's training. We theoretically prove the sub-optimality of sequential post-training. Furthermore, we propose a practical joint post-training framework that has theoretical convergence guarantees and empirically outperforms the sequential post-training framework, while having a similar computational cost. Our code is available at https://github.com/heshandevaka/XRIGHT.
Analysis
1. Bibliographic Information
- Title: Mitigating Forgetting in LLM Supervised Fine-Tuning and Preference Learning
- Authors: Heshan Fernando, Han Shen, Parikshit Ram, Yi Zhou, Horst Samulowitz, Nathalie Baracaldo, Tianyi Chen.
- Affiliations: The authors are affiliated with Rensselaer Polytechnic Institute and IBM Research.
- Journal/Conference: The paper is a preprint available on arXiv. As of its version v3, it has not yet been published in a peer-reviewed conference or journal, which is common for fast-moving fields like ML.
- Publication Year: 2024 (First version submitted in October).
- Abstract: The paper addresses a key issue in the post-training of Large Language Models (LLMs), which involves two main stages: Supervised Fine-Tuning (SFT) and preference learning (such as RLHF or DPO). The standard method is to perform these stages sequentially. The authors argue this is suboptimal because the model "forgets" the learning from the first stage while training on the second. They provide a theoretical proof of this sub-optimality. To solve it, they propose a joint post-training framework with two practical algorithms, ALRIGHT and MAXRIGHT. These methods are shown to achieve a better trade-off between SFT and preference-learning performance, have theoretical convergence guarantees, and maintain a similar computational cost to the sequential method.
- Original Source Link: https://arxiv.org/abs/2410.15483v3
- The paper is currently available as a preprint on arXiv.
2. Executive Summary
- Background & Motivation (Why):
- Modern LLMs are made useful and safe through a post-training process that typically includes two distinct steps: Supervised Fine-Tuning (SFT) to teach the model how to follow instructions, and preference learning (e.g., DPO) to align it with human values and preferences.
- The common practice is to perform these steps sequentially: first SFT, then DPO, or vice-versa. The paper identifies a critical flaw in this approach: catastrophic forgetting. When the model is fine-tuned on the second task's data (e.g., DPO), its performance on the first task (e.g., SFT) degrades. This results in a suboptimal trade-off between instruction-following ability and preference alignment.
- A naive solution would be to mix the SFT and DPO objectives and train them simultaneously. However, this is computationally very expensive, often prohibitively so for LLMs, as it requires constructing and managing two separate computation graphs in memory.
- This leaves a significant gap: a need for a post-training method that can effectively balance both SFT and DPO objectives without the performance loss of sequential training or the high computational cost of naive mixing.
- Main Contributions / Findings (What): The paper makes three primary contributions:
- Theoretical and Empirical Insight into Forgetting: The authors are the first to theoretically prove that sequential post-training leads to a non-diminishing optimality gap. This means that no matter how long you train, the final model will always be a fixed, suboptimal distance away from the ideal balance between SFT and DPO objectives. They also provide empirical evidence to support this claim.
- Principled and Efficient Post-Training Algorithms: They propose two novel algorithms that jointly train the SFT and DPO objectives with minimal extra computational cost:
- ALRIGHT (ALternating supeRvised fInetuninG and Human preference alignmenT): This method alternates between SFT and DPO gradient updates based on a fixed probability. It is simple, efficient, and proven to converge to any desired trade-off point.
- MAXRIGHT (MAXimum supeRvised fIne-tuninG and Human preference alignmenT): This is an adaptive version of ALRIGHT. At each step, it intelligently chooses to update the objective (SFT or DPO) that is currently performing worse, thereby focusing effort where it is most needed.
- Strong Empirical Validation: Experiments on PYTHIA-1B and LLAMA3-8B models show that ALRIGHT and MAXRIGHT significantly outperform the standard sequential method. For instance, on LLAMA3-8B, they achieve up to a 3% improvement on the MMLU benchmark and a 31% increase in win rate against a baseline, all while having computational costs comparable to the sequential approach and far lower than naive mixing.
3. Prerequisite Knowledge & Related Work
- Foundational Concepts:
- Large Language Models (LLMs): These are massive neural networks (e.g., GPT-4, Llama 3) trained on vast amounts of text data. The initial training is called "pre-training," which gives them general language understanding. To be useful for specific tasks, they undergo "post-training."
- Supervised Fine-Tuning (SFT): This is the first step in post-training. The pre-trained LLM is further trained on a curated dataset of high-quality instruction-response pairs (e.g., "Question: What is the capital of France?", "Answer: The capital of France is Paris."). This teaches the model to follow instructions and generate helpful responses.
- Reinforcement Learning from Human Feedback (RLHF): A technique to align LLM behavior with what humans prefer. It typically involves training a "reward model" on human-ranked responses and then using reinforcement learning to fine-tune the LLM to maximize this reward.
- Direct Preference Optimization (DPO): A more recent and simpler alternative to RLHF. Instead of training a separate reward model, DPO directly optimizes the LLM on a preference dataset containing pairs of "chosen" (good) and "rejected" (bad) responses. It uses a specific loss function that encourages the LLM to assign a higher probability to the chosen response than the rejected one.
- Catastrophic Forgetting: A well-known problem in neural networks where a model trained sequentially on multiple tasks forgets the knowledge acquired from earlier tasks. This paper shows it happens in LLM post-training between the SFT and DPO stages.
- Multi-objective Optimization & Pareto Optimality: When you have multiple conflicting goals (like minimizing both the SFT loss and the DPO loss), there's usually no single solution that is best for all of them. A solution is Pareto optimal if you cannot improve one objective without making another one worse. The set of all such solutions is called the Pareto front, representing the best possible trade-offs. The paper's goal is to find methods that can effectively reach points on this Pareto front.
- Previous Works: The paper positions itself within the context of existing work on LLM alignment and fine-tuning.
- Sequential RLHF and SFT: The authors note that many popular open-source models, such as PHI-3 and LLAMA-3, use a sequential post-training recipe. While effective, studies such as Qi et al. (2023) have observed that this harms either alignment or fine-tuning performance.
- Reconciling RLHF and SFT:
- Some methods try to merge the two stages by adding regularization to the SFT objective or reformulating it, but these often require using the same dataset for both, which is restrictive.
- Adaptive Model Averaging (AMA): Lin et al. (2023) proposed a method to balance the objectives, but it is computationally expensive, requiring three copies of the model. A memory-efficient version exists but moves away from the original objectives.
- Controlling Preferences via Prompts: Other works (Yang et al., 2024; Guo et al., 2024) control different preferences by modifying the training data with special prompts. This differs from the paper's approach, which focuses on the trade-off between the fundamental SFT and DPO tasks themselves, without needing specialized data.
- Differentiation: This paper's key innovation is providing a theoretically grounded and computationally efficient solution to the SFT-DPO trade-off problem. Unlike prior work, it formally proves the sub-optimality of the sequential approach and proposes two simple, alternating algorithms (ALRIGHT and MAXRIGHT) that directly address the forgetting issue without the high overhead of naive mixing or the complexities of other methods.
4. Methodology (Core Technology & Implementation)
The paper first formalizes the problem and then proposes its solutions.
- Principles: The core idea is that sequentially optimizing two different objectives, the DPO loss $L_{\mathrm{DPO}}(\theta)$ and the SFT loss $L_{\mathrm{SFT}}(\theta)$, is inherently flawed due to conflicting data distributions and objectives. The model forgets one while learning the other. A better approach is to jointly optimize them. Since naively mixing them is too expensive, the authors propose alternating between updates for each objective in a principled manner.
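Concretely, the joint problem can be written as the following scalarized objective (placing $\lambda$ on the DPO term is an assumption here, chosen to be consistent with the λ/(T_SFT, T_DPO) settings reported later in this analysis):

$$\min_{\theta}\; \lambda\, L_{\mathrm{DPO}}(\theta) + (1-\lambda)\, L_{\mathrm{SFT}}(\theta), \qquad \lambda \in [0, 1]$$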
- Steps & Procedures: Let's break down the objectives and algorithms.
1. DPO Objective: The goal of DPO is to maximize the likelihood of preferred responses ($y_w$) and minimize the likelihood of rejected ones ($y_l$), relative to a reference model. The loss function is:
$$L_{\mathrm{DPO}}(\theta) = -\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}_{\mathrm{DPO}}}\left[\log \sigma\!\left(\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}\right)\right]$$
- $\theta$: The parameters of the LLM we are training.
- $\pi_\theta(y \mid x)$: The probability of the LLM generating response $y$ given prompt $x$.
- $\pi_{\mathrm{ref}}$: A fixed reference model, usually the model before DPO training. This acts as a regularizer to prevent the model from straying too far from its original capabilities.
- $\beta$: A hyperparameter that controls the strength of the regularization.
- $\sigma$: The sigmoid function, $\sigma(z) = 1/(1 + e^{-z})$.
- $\mathcal{D}_{\mathrm{DPO}}$: The preference dataset of prompts, chosen responses, and rejected responses.
- This loss is minimized when the model assigns much higher probabilities to chosen responses ($y_w$) than to rejected ones ($y_l$), relative to the reference model.
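To make the formula concrete, here is a minimal PyTorch-style sketch of the DPO loss on a batch of pre-computed sequence log-probabilities; the function and argument names are illustrative, not the paper's implementation.

```python
import torch.nn.functional as F

def dpo_loss(logp_chosen, logp_rejected, ref_logp_chosen, ref_logp_rejected, beta=0.1):
    """Sketch of the DPO loss for one batch.

    Each argument is a 1-D tensor of summed per-sequence log-probabilities
    log pi(y | x) for the chosen / rejected responses under the trainable
    policy (logp_*) and the frozen reference policy (ref_logp_*).
    """
    # Implicit rewards: beta * log(pi_theta / pi_ref) for each response.
    chosen_reward = beta * (logp_chosen - ref_logp_chosen)
    rejected_reward = beta * (logp_rejected - ref_logp_rejected)
    # -log sigma(chosen_reward - rejected_reward), averaged over the batch.
    return -F.logsigmoid(chosen_reward - rejected_reward).mean()
```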
2. SFT Objective: The goal of SFT is to maximize the likelihood of generating a specific target response $y$ for a prompt $x$. The loss is a standard negative log-likelihood:
$$L_{\mathrm{SFT}}(\theta) = -\mathbb{E}_{(x, y) \sim \mathcal{D}_{\mathrm{SFT}}}\left[\log \pi_\theta(y \mid x)\right]$$
- $\mathcal{D}_{\mathrm{SFT}}$: The SFT dataset of instruction-response pairs.
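The SFT term is ordinary token-level cross-entropy on the target response. A minimal sketch, assuming logits over the vocabulary and labels with non-target (prompt) positions masked to -100:

```python
import torch.nn.functional as F

def sft_loss(logits, labels):
    """Negative log-likelihood of the target tokens.

    logits: (batch, seq_len, vocab) model outputs; labels: (batch, seq_len)
    token ids, with non-target positions (e.g., the prompt) set to -100 so
    cross_entropy ignores them.
    """
    return F.cross_entropy(
        logits[:, :-1].reshape(-1, logits.size(-1)),  # predict token t+1 from position t
        labels[:, 1:].reshape(-1),
        ignore_index=-100,
    )
```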
3. The Problem with Sequential Training (Algorithm 1): The standard approach first trains on one objective for a fixed number of steps, then uses the resulting model as the starting point to train on the second objective. The paper's Theorem 3.3 proves that, for any desired trade-off, this method suffers a constant, non-vanishing optimality gap: no matter how long it trains, the final model remains a fixed distance away from the optimal trade-off point, stuck in a suboptimal region.
Image 2 above illustrates this problem. On the left, the "Sequential" path shows a model first trained with DPO (learning to be helpful but refusing harmful requests) and then with SFT (forgetting its safety alignment to become an "obedient agent"). The proposed methods avoid this. On the right, plot (a) shows the optimization trajectory for sequential training. The model first moves toward the DPO optimum (DPO Opt.), then sharply turns to the SFT optimum (SFT Opt.), ending up far from the ideal point where both losses are low.
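For contrast, the sequential recipe amounts to two back-to-back training loops. A schematic sketch (hypothetical helper names; dpo_step and sft_step each stand for one gradient update on the corresponding loss):

```python
def sequential_post_training(model, dpo_batches, sft_batches, dpo_step, sft_step):
    # Stage 1: optimize only the DPO objective.
    for batch in dpo_batches:
        dpo_step(model, batch)
    # Stage 2: continue from the stage-1 model, optimizing only the SFT objective.
    # Nothing constrains the DPO loss here, which is where forgetting creeps in.
    for batch in sft_batches:
        sft_step(model, batch)
    return model
```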
4. ALRIGHT (Algorithm 2): This is the first proposed solution. It's a simple yet effective joint training method.
- Core Idea: In each training step, flip a biased coin. With probability $\lambda$, perform a DPO update; with probability $1 - \lambda$, perform an SFT update.
- Algorithm Flow:
- Initialize the model parameters $\theta^0$.
- For each training step $t = 0, 1, \dots, T-1$:
- Sample $u_t \sim \mathrm{Bernoulli}(\lambda)$.
- If $u_t = 1$: sample a batch from $\mathcal{D}_{\mathrm{DPO}}$ and update $\theta^t$ using the DPO gradient.
- If $u_t = 0$: sample a batch from $\mathcal{D}_{\mathrm{SFT}}$ and update $\theta^t$ using the SFT gradient.
- Why it works: In expectation, this process optimizes the mixed objective $\lambda L_{\mathrm{DPO}}(\theta) + (1 - \lambda) L_{\mathrm{SFT}}(\theta)$. Theorem 4.1 shows that the optimality gap of ALRIGHT vanishes as the number of training iterations grows, meaning it converges to the desired trade-off point as training progresses. The hyperparameter $\lambda$ directly controls the trade-off.
- As shown in Image 2, plot (b), the ALRIGHT trajectory moves more directly toward a balanced point, achieving a much better trade-off than the sequential method.
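A minimal sketch of this alternating scheme, under the assumption that λ weights the DPO objective as above (names are illustrative, and optimizer/batching details are simplified):

```python
import random
from itertools import cycle

def alright_training(model, dpo_loader, sft_loader, dpo_step, sft_step,
                     lam=0.5, num_steps=1000, seed=0):
    """With probability `lam` take a DPO step, otherwise an SFT step.

    In expectation this optimizes lam * L_DPO + (1 - lam) * L_SFT, so `lam`
    selects the target trade-off point.
    """
    rng = random.Random(seed)
    dpo_iter, sft_iter = cycle(dpo_loader), cycle(sft_loader)
    for _ in range(num_steps):
        if rng.random() < lam:
            dpo_step(model, next(dpo_iter))   # one gradient update on the DPO loss
        else:
            sft_step(model, next(sft_iter))   # one gradient update on the SFT loss
    return model
```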
5. MAXRIGHT (Algorithm 3): This is the second, more adaptive algorithm.
- Core Idea: Instead of randomly choosing which objective to update, intelligently pick the one that is currently "worse". This ensures the training focuses on the lagging objective, promoting a balanced improvement.
- Algorithm Flow:
- Initialize the model parameters $\theta^0$.
- For each training step $t$:
- Evaluate the weighted sub-optimality of both objectives: $\lambda\,(L_{\mathrm{DPO}}(\theta^t) - L_{\mathrm{DPO}}^*)$ and $(1 - \lambda)\,(L_{\mathrm{SFT}}(\theta^t) - L_{\mathrm{SFT}}^*)$.
- Here, $L_{\mathrm{DPO}}^*$ and $L_{\mathrm{SFT}}^*$ are the (pre-computed or estimated) minimum possible loss values for each objective.
- If the weighted DPO sub-optimality is the larger one: perform a DPO update (as the DPO objective is currently worse).
- Else: perform an SFT update.
- Practical Consideration: Evaluating both losses at every step is expensive. The authors propose a memory-efficient version in which both losses are fully re-evaluated only every few iterations (the Max Eval. Steps hyperparameter studied in the experiments). In between, the algorithm uses the "stale" loss values to make its decision, significantly reducing computational overhead.
- As shown in Image 2, plot (c), MAXRIGHT's trajectory is even more direct, heading straight for an ideal trade-off point.
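A sketch of the adaptive selection rule, including the memory-efficient variant that refreshes the two loss estimates only every max_eval_steps iterations; the exact weighting and helper names are assumptions based on the description above:

```python
from itertools import cycle

def maxright_training(model, dpo_loader, sft_loader, dpo_step, sft_step,
                      eval_dpo_loss, eval_sft_loss, dpo_opt, sft_opt,
                      lam=0.5, num_steps=1000, max_eval_steps=10):
    """Update whichever objective currently has the larger weighted sub-optimality.

    dpo_opt / sft_opt are (estimates of) the minimum achievable losses, e.g.
    taken from models trained on each objective alone. The losses are fully
    re-evaluated only every `max_eval_steps` iterations; in between, the stale
    values decide which objective receives the next gradient step.
    """
    dpo_iter, sft_iter = cycle(dpo_loader), cycle(sft_loader)
    gap_dpo = gap_sft = 0.0
    for t in range(num_steps):
        if t % max_eval_steps == 0:  # periodic (possibly expensive) evaluation
            gap_dpo = lam * (eval_dpo_loss(model) - dpo_opt)
            gap_sft = (1.0 - lam) * (eval_sft_loss(model) - sft_opt)
        if gap_dpo >= gap_sft:
            dpo_step(model, next(dpo_iter))   # DPO objective is lagging: update it
        else:
            sft_step(model, next(sft_iter))   # SFT objective is lagging: update it
    return model
```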
5. Experimental Setup
- Datasets:
- DPO Dataset: DAHOAS/RM-HH-RLHF - A dataset containing human feedback, structured as prompts with preferred and rejected responses, designed for aligning models with human preferences (particularly for helpfulness and harmlessness).
- SFT Dataset: VICGALLE/ALPACA-GPT4 - A dataset of English instruction-following examples generated by GPT-4 based on Alpaca prompts. It is used to teach the model to follow instructions effectively.
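For readers who want to inspect the data, both corpora are distributed as Hugging Face datasets. A hedged loading sketch (assuming the identifiers above resolve on the Hub with this casing and a standard train split):

```python
from datasets import load_dataset

# Preference data (prompt, chosen, rejected) used for the DPO objective.
# Identifier casing taken from the paper's description; verify on the Hub.
pref_data = load_dataset("Dahoas/rm-hh-rlhf", split="train")

# Instruction-following data (instruction, output) used for the SFT objective.
sft_data = load_dataset("vicgalle/alpaca-gpt4", split="train")

print(pref_data.column_names, len(pref_data))
print(sft_data.column_names, len(sft_data))
```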
- Evaluation Metrics:
- DPO/SFT Optimality Gap:
- Conceptual Definition: Measures how far the current model's loss is from the best possible loss for a given objective. A smaller gap means better performance on that specific task.
- Mathematical Formula: $L_i(\theta) - L_i^*$, for $i \in \{\mathrm{DPO}, \mathrm{SFT}\}$.
- Symbol Explanation: $L_i(\theta)$ is the current loss (either $L_{\mathrm{DPO}}$ or $L_{\mathrm{SFT}}$) for model $\theta$. $L_i^*$ is the minimum achievable loss, which the authors approximate by training a model solely on that objective.
- Ideal Distance:
- Conceptual Definition: Measures the Euclidean distance in the 2D loss space from the current model's performance point to the "ideal point" $(L_{\mathrm{DPO}}^*, L_{\mathrm{SFT}}^*)$. This single metric captures the overall trade-off quality. A smaller distance is better (a small helper computing both metrics appears after this list).
- Mathematical Formula: $\sqrt{(L_{\mathrm{DPO}}(\theta) - L_{\mathrm{DPO}}^*)^2 + (L_{\mathrm{SFT}}(\theta) - L_{\mathrm{SFT}}^*)^2}$
- Symbol Explanation: The symbols are the same as in the optimality gap definitions.
- MMLU (Massive Multitask Language Understanding):
- Conceptual Definition: A widely used benchmark to evaluate a model's general knowledge and problem-solving abilities. It consists of multiple-choice questions across 57 diverse subjects, including humanities, social sciences, and STEM. The paper uses the 1-shot version, where the model sees one example before answering.
- Win Rate:
- Conceptual Definition: A metric for evaluating preference alignment. It measures how often a model's responses are preferred by an evaluator (in this case, GPT-4-TURBO) over a baseline model's responses in head-to-head comparisons. A higher win rate indicates better alignment with human preferences. The evaluation is done using the AlpacaEval framework.
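A small helper showing how the two loss-space metrics above are computed from the current and (approximated) optimal losses; names are illustrative:

```python
import math

def optimality_gaps_and_ideal_distance(loss_dpo, loss_sft, opt_dpo, opt_sft):
    """Per-objective optimality gaps and their Euclidean distance to the ideal point."""
    gap_dpo = loss_dpo - opt_dpo
    gap_sft = loss_sft - opt_sft
    ideal_distance = math.hypot(gap_dpo, gap_sft)  # distance to (opt_dpo, opt_sft) in loss space
    return gap_dpo, gap_sft, ideal_distance
```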
- Baselines:
- Mix (Naive Mixing): Optimizes a linear combination of the SFT and DPO loss functions in every step (see the sketch after this list). This is considered a strong but computationally expensive baseline.
- Sequential: The standard method of training DPO first, then SFT (or vice versa).
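For completeness, one step of the Mix baseline backpropagates through a weighted sum of both losses, which is why it needs both batches (and both computation graphs) at once. A schematic sketch with illustrative names:

```python
def mix_step(model, optimizer, dpo_batch, sft_batch, dpo_loss_fn, sft_loss_fn, lam=0.5):
    """One step of the naive Mix baseline: backprop through both losses together."""
    optimizer.zero_grad()
    loss = lam * dpo_loss_fn(model, dpo_batch) + (1.0 - lam) * sft_loss_fn(model, sft_batch)
    loss.backward()   # both objectives contribute to every update -> higher memory/compute
    optimizer.step()
    return loss.item()
```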
6. Results & Analysis
The experiments provide strong evidence for the authors' claims.
- Core Results on Trade-off and Efficiency: Image 3 above (from experiments on the Pythia-1B model) clearly illustrates the main findings. The plots on the left show the Pareto fronts traced by the different methods.
- The Sequential method (green crosses) produces a very poor Pareto front. The final models are clustered towards the SFT-optimal side, showing it has "forgotten" the DPO training.
- The Mix method (purple triangles) and ALRIGHT (red stars) trace a much better Pareto front, demonstrating their ability to achieve a wide range of good trade-offs.
- MAXRIGHT (blue stars, in the bottom-left plot) produces final models that are consistently closest to the ideal point (black star), indicating it finds the most balanced solutions.
- The bar charts on the right quantify this. For any given trade-off (λ or (T_DPO, T_SFT)), ALRIGHT and MAXRIGHT achieve much lower DPO/SFT optimality gaps and Ideal Distance than Sequential. Crucially, their Increase in Runtime (%) and Increase in GPU Utilization (%) are minimal (often near 0%), whereas Mix incurs a huge cost (over 50% more runtime).
- Results on Real-World Benchmarks (LLAMA3-8B): The following table, transcribed from Table 1 in the paper, shows performance on downstream tasks. Each column pairs a mixing weight λ (used by Mix, ALRIGHT, and MAXRIGHT) with a corresponding epoch split (T_SFT, T_DPO) used by Sequential; all values are percentages.

| Method | MMLU (1-shot): 0.25 / (3, 1) | MMLU (1-shot): 0.5 / (2, 2) | MMLU (1-shot): 0.75 / (1, 3) | Win rate: 0.25 / (3, 1) | Win rate: 0.5 / (2, 2) | Win rate: 0.75 / (1, 3) |
|---|---|---|---|---|---|---|
| Sequential | 73.18 | 72.80 | 72.68 | 57.19 | 65.62 | 59.38 |
| Mix | 73.45 | 73.40 | 72.29 | 81.88 | 84.22 | 88.42 |
| ALRIGHT | 74.66 | 72.65 | 75.50 | 88.28 | 85.78 | 87.34 |
| MAXRIGHT | 72.35 | 73.42 | 74.24 | 86.56 | 86.09 | 83.75 |

- Analysis:
- On MMLU, which measures general knowledge and reasoning (closer to SFT performance), ALRIGHT and MAXRIGHT achieve scores that are often higher than Sequential (e.g., 75.50 vs. 72.68). This shows they can maintain or improve instruction-following capabilities.
- On Win rate, which measures preference alignment (DPO performance), ALRIGHT and MAXRIGHT are drastically better than Sequential (e.g., 88.28% vs. 57.19%) and are on par with or even better than the expensive Mix method. This confirms they successfully mitigate the forgetting of preference alignment.
- Ablations / Parameter Sensitivity (MAXRIGHT): Image 4 analyzes the Max Eval. Steps hyperparameter for the memory-efficient MAXRIGHT.
- Max Eval. Steps = 1: This is the original MAXRIGHT, where both losses are checked at every step. It achieves the best Ideal Distance but has the highest runtime increase (around 75%).
- Max Eval. Steps = 10 or 100: This provides a sweet spot. The performance (Ideal Distance) is nearly as good as checking every step, but the runtime cost is dramatically lower (closer to a 0-10% increase).
- Much larger Max Eval. Steps: The delay is too long. The algorithm uses very stale information, causing the optimization path to oscillate wildly and leading to poor final performance.
- Conclusion: This shows the practical utility of the memory-efficient MAXRIGHT. A small Max Eval. Steps value (e.g., 10) allows it to retain its adaptive advantage while matching the low computational cost of ALRIGHT and Sequential.
7. Conclusion & Reflections
- Conclusion Summary: The paper convincingly demonstrates that the standard sequential post-training pipeline for LLMs is fundamentally flawed due to catastrophic forgetting. It provides both a theoretical proof of this sub-optimality and strong empirical evidence. The proposed ALRIGHT and MAXRIGHT algorithms offer a simple, elegant, and highly effective solution. By alternating between SFT and DPO updates in a principled way, they achieve a superior trade-off between instruction-following and preference alignment. Most importantly, they do so with minimal additional computational resources, making them a practical drop-in replacement for sequential training in real-world LLM development.
- Limitations & Future Work: The paper itself does not explicitly list limitations, but we can infer some:
- Theoretical Assumptions: The theoretical analysis relies on a softmax characterization of the LLM policy. While this is a common simplification, modern LLMs are more complex. The theory provides valuable intuition, but its direct applicability to full-parameter training of complex architectures might be an approximation.
- Hyperparameter Tuning: Both ALRIGHT and MAXRIGHT introduce the trade-off parameter λ. Finding the "best" λ for a specific application still requires hyperparameter tuning to navigate the Pareto front.
- MAXRIGHT's Dependency: The MAXRIGHT algorithm requires knowing or estimating the optimal loss values $L_{\mathrm{DPO}}^*$ and $L_{\mathrm{SFT}}^*$, which adds a pre-computation step. While the authors state this is a one-time cost, it is an extra requirement compared to ALRIGHT.
- Scope: The work focuses on DPO. Extending the analysis and methods to other preference learning algorithms, such as PPO-based RLHF, would be a valuable next step.
- Personal Insights & Critique:
- This is a high-quality paper that addresses a very practical and important problem in LLM training. The combination of a strong theoretical result (Theorem 3.3) and a simple, effective, and well-validated practical solution is a significant contribution.
- The core idea of alternating updates is not entirely new in multi-task learning, but its application to the SFT-DPO context, along with the rigorous analysis and the proposal of the adaptive MAXRIGHT variant, is novel and impactful.
- MAXRIGHT is particularly clever. Its adaptive nature feels more principled than the random alternation of ALRIGHT and is likely to be more robust in practice, especially in the memory-efficient form.
- The paper does an excellent job of demonstrating the "why." The illustrations in Figure 2 and the empirical results in Figure 3 make the problem and the solution intuitively clear. This work could change the default post-training recipe for open-source LLMs, moving the community from a sequential to a joint-alternating approach.