Steps & Procedures:
Let's break down the objectives and algorithms.
1. DPO Objective:
The goal of DPO is to raise the likelihood of preferred responses (y_w) and lower the likelihood of rejected ones (y_ℓ), relative to a fixed reference model. The loss function is:
$$f_{\mathrm{DPO}}(\theta) := -\,\mathbb{E}_{(x,\,y_w,\,y_\ell)\sim\mathcal{D}_{\mathrm{DPO}}}\left[\log\sigma\left(\beta\log\frac{\pi_\theta(y_w\mid x)}{\pi_{\mathrm{ref}}(y_w\mid x)} \;-\; \beta\log\frac{\pi_\theta(y_\ell\mid x)}{\pi_{\mathrm{ref}}(y_\ell\mid x)}\right)\right]$$
- θ: The parameters of the LLM we are training.
- πθ(y∣x): The probability of the LLM generating response y given prompt x.
- πref: A fixed reference model, usually the model before DPO training. This acts as a regularizer to prevent the model from straying too far from its original capabilities.
- β: A hyperparameter that controls the strength of the regularization.
- σ: The sigmoid function, σ(z)=1/(1+e−z).
- D_DPO: The preference dataset of prompts, chosen responses, and rejected responses.
- This loss is driven down as the model assigns higher probability to chosen responses (y_w) than to rejected ones (y_ℓ), relative to the reference model; a minimal code sketch of this computation follows below.
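To make the pieces concrete, here is a minimal PyTorch sketch of this loss. It assumes the per-sequence log-probabilities log πθ(y∣x) and log πref(y∣x) have already been computed (summed over the response tokens); the function name `dpo_loss` and the toy inputs are illustrative, not taken from the paper.

```python
import torch
import torch.nn.functional as F

def dpo_loss(policy_logp_w, policy_logp_l, ref_logp_w, ref_logp_l, beta=0.1):
    """DPO loss from per-sequence log-probabilities of chosen (w) and rejected (l) responses.

    Each argument is a (batch,) tensor holding log pi(y|x) summed over response tokens.
    """
    # beta * [log(pi_theta(y_w|x)/pi_ref(y_w|x)) - log(pi_theta(y_l|x)/pi_ref(y_l|x))]
    margin = beta * ((policy_logp_w - ref_logp_w) - (policy_logp_l - ref_logp_l))
    # -log sigma(margin), averaged over the batch
    return -F.logsigmoid(margin).mean()

# Toy usage with random log-probabilities for a batch of 4 preference pairs.
torch.manual_seed(0)
policy_w, policy_l = -torch.rand(4) * 5, -torch.rand(4) * 5
ref_w, ref_l = -torch.rand(4) * 5, -torch.rand(4) * 5
print(dpo_loss(policy_w, policy_l, ref_w, ref_l).item())
```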
2. SFT Objective:
The goal of SFT is to maximize the likelihood of generating a specific target response y. The loss is a standard negative log-likelihood:
$$f_{\mathrm{SFT}}(\theta) := -\,\mathbb{E}_{(x,\,y)\sim\mathcal{D}_{\mathrm{SFT}}}\left[\log\pi_\theta(y\mid x)\right]$$
- D_SFT: The SFT dataset of instruction-response pairs.
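A matching PyTorch sketch for the SFT loss. It assumes the model's next-token logits and the target token ids are already aligned; treating the loss as token-level cross-entropy with padding ignored is a common implementation choice, not something dictated by the paper.

```python
import torch
import torch.nn.functional as F

def sft_loss(logits, target_ids, pad_id=0):
    """Negative log-likelihood of the target response y under the model pi_theta.

    logits: (batch, seq_len, vocab) next-token logits; target_ids: (batch, seq_len).
    Positions equal to pad_id are ignored.
    """
    return F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),  # flatten to (batch*seq_len, vocab)
        target_ids.reshape(-1),
        ignore_index=pad_id,
    )

# Toy usage: a batch of 2 sequences of length 5 over a vocabulary of 11 tokens.
logits = torch.randn(2, 5, 11)
targets = torch.randint(1, 11, (2, 5))
print(sft_loss(logits, targets).item())
```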
3. The Problem with Sequential Training (Algorithm 1):
The standard approach first trains on one objective (say DPO) for T_DPO steps, then uses the resulting model as the starting point for T_SFT steps of training on the other objective (SFT). The paper's Theorem 3.3 proves that for any desired trade-off, this method incurs a constant, non-zero error (Ω(1)) that does not shrink with more training: it never converges to an optimal trade-off point and is always stuck in a suboptimal region. A toy illustration of this two-phase schedule is sketched below.
Image 2 above illustrates this problem. On the left, the "Sequential" path shows a model first trained with DPO (learning to be helpful but refusing harmful requests) and then with SFT (forgetting its safety alignment to become an "obedient agent"). The proposed methods avoid this. On the right, plot (a) shows the optimization trajectory for sequential training. The model first moves toward the DPO optimum (DPO Opt.), then sharply turns to the SFT optimum (SFT Opt.), ending up far from the ideal point where both losses are low.
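The two-phase schedule is easy to reproduce on a toy problem. The sketch below uses two 2-D quadratic surrogates in place of the real f_DPO and f_SFT (an assumption made purely for illustration); running it shows the final iterate sitting at the second objective's optimum, echoing the trajectory in plot (a).

```python
import torch

# Toy 2-D stand-ins for the two losses, used only to visualize the trajectory.
# Their minimizers play the roles of "DPO Opt." and "SFT Opt." in plot (a).
def f_dpo(theta):                                            # minimized at (0, 0)
    return (theta ** 2).sum()

def f_sft(theta):                                            # minimized at (2, 1)
    return ((theta - torch.tensor([2.0, 1.0])) ** 2).sum()

theta = torch.tensor([3.0, -2.0], requires_grad=True)
opt = torch.optim.SGD([theta], lr=0.1)

T_DPO, T_SFT = 50, 50
for t in range(T_DPO + T_SFT):
    opt.zero_grad()
    loss = f_dpo(theta) if t < T_DPO else f_sft(theta)       # phase 1, then phase 2
    loss.backward()
    opt.step()

# The final iterate sits near the SFT optimum and has drifted away from the DPO
# optimum: a toy analogue of the constant trade-off error formalized in Theorem 3.3.
print(theta.detach(), f_dpo(theta).item(), f_sft(theta).item())
```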
4. ALRIGHT (Algorithm 2):
This is the first proposed solution. It's a simple yet effective joint training method.
- Core Idea: In each training step, flip a coin. With probability λ, perform a DPO update. With probability 1−λ, perform an SFT update.
- Algorithm Flow:
- Initialize model θ_1.
- For each training step t = 1, …, T−1:
- Sample i_t ∼ Bernoulli(λ).
- If i_t = 1: Sample a batch from D_DPO and update θ using the DPO gradient.
- If i_t = 0: Sample a batch from D_SFT and update θ using the SFT gradient.
- Why it works: In expectation, this process optimizes the mixed objective λ·f_DPO(θ) + (1−λ)·f_SFT(θ): each update is, on average, a gradient step on that weighted sum. Theorem 4.1 shows that the optimality gap of ALRIGHT decreases as O(log T / T), meaning it converges to the desired trade-off point as training progresses. The hyperparameter λ directly controls the trade-off (see the sketch after this list).
- As shown in Image 2, plot (b), the ALRIGHT trajectory moves more directly toward a balanced point, achieving a much better trade-off than the sequential method.
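Continuing with the same kind of toy surrogate losses (an illustrative assumption, not the paper's experimental setup), here is a minimal sketch of the ALRIGHT loop; the key point is that the Bernoulli(λ) draw makes the expected update a gradient step on λ·f_DPO(θ) + (1−λ)·f_SFT(θ).

```python
import torch

def alright(f1, f2, theta, lam=0.5, lr=0.1, T=200, seed=0):
    """ALRIGHT-style joint training on a toy parameter vector.

    f1, f2: callables mapping theta -> scalar loss (stand-ins for f_DPO and f_SFT).
    At each step, f1 is chosen with probability lam and f2 otherwise.
    """
    gen = torch.Generator().manual_seed(seed)
    opt = torch.optim.SGD([theta], lr=lr)
    for _ in range(T):
        pick_dpo = torch.bernoulli(torch.tensor(lam), generator=gen).item() == 1.0
        opt.zero_grad()
        loss = f1(theta) if pick_dpo else f2(theta)   # i_t = 1 -> DPO step, else SFT step
        loss.backward()
        opt.step()
    return theta

# Toy quadratic surrogates with different minimizers, used only for illustration.
f_dpo = lambda th: (th ** 2).sum()                               # minimized at (0, 0)
f_sft = lambda th: ((th - torch.tensor([2.0, 1.0])) ** 2).sum()  # minimized at (2, 1)
theta = torch.tensor([3.0, -2.0], requires_grad=True)
alright(f_dpo, f_sft, theta, lam=0.5)
print(theta.detach())  # hovers near the lambda-weighted balance of the two optima
```

With equal curvature in the two toy losses, λ = 0.5 lands the iterate near the midpoint of the two optima; changing λ slides the endpoint along the trade-off curve, mirroring how λ controls the trade-off in the real objective.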
5. MAXRIGHT (Algorithm 3):
This is the second, more adaptive algorithm.
- Core Idea: Instead of randomly choosing which objective to update, intelligently pick the one that is currently "worse". This ensures the training focuses on the lagging objective, promoting a balanced improvement.
- Algorithm Flow:
- Initialize model θ_1.
- For each training step t = 1, …, T−1:
- Evaluate the weighted sub-optimality for both objectives:
- f̄_1,λ(θ_t) = λ · (f_DPO(θ_t) − f*_DPO)
- f̄_2,λ(θ_t) = (1−λ) · (f_SFT(θ_t) − f*_SFT)
- Here, f*_DPO and f*_SFT are the (pre-computed or estimated) minimum possible loss values for each objective.
- If f̄_1,λ(θ_t) ≥ f̄_2,λ(θ_t): Perform a DPO update (as the DPO objective is currently worse).
- Else: Perform an SFT update.
- Practical Consideration: Calculating both losses at every step is expensive. The authors propose a memory-efficient version in which both losses are evaluated only every k steps; in between, the algorithm reuses the stale loss values to make its choice, significantly reducing computational overhead. A sketch of this variant appears after this list.
- As shown in Image 2, plot (c), MAXRIGHT's trajectory is even more direct, heading straight for an ideal trade-off point.
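A corresponding sketch of MAXRIGHT on the same kind of toy surrogates (again an illustrative assumption, not the paper's implementation). It includes the memory-efficient option: both weighted gaps are recomputed only every k steps and the stale values drive the choice in between. Both toy losses have a known minimum of 0, standing in for f*_DPO and f*_SFT.

```python
import torch

def maxright(f1, f2, theta, f1_star=0.0, f2_star=0.0, lam=0.5, lr=0.1, T=200, k=10):
    """MAXRIGHT-style training on a toy parameter vector.

    At each step, update whichever objective currently has the larger weighted
    sub-optimality. Both gaps are re-evaluated only every k steps; in between,
    the stale values are reused (the memory-efficient variant).
    """
    opt = torch.optim.SGD([theta], lr=lr)
    gap1 = gap2 = 0.0
    for t in range(T):
        if t % k == 0:
            with torch.no_grad():   # periodic evaluation of both weighted gaps
                gap1 = lam * (f1(theta).item() - f1_star)
                gap2 = (1.0 - lam) * (f2(theta).item() - f2_star)
        opt.zero_grad()
        loss = f1(theta) if gap1 >= gap2 else f2(theta)   # update the lagging objective
        loss.backward()
        opt.step()
    return theta

# Toy quadratic surrogates; both have a minimum value of 0 by construction.
f_dpo = lambda th: (th ** 2).sum()                               # minimized at (0, 0)
f_sft = lambda th: ((th - torch.tensor([2.0, 1.0])) ** 2).sum()  # minimized at (2, 1)
theta = torch.tensor([3.0, -2.0], requires_grad=True)
maxright(f_dpo, f_sft, theta)
print(theta.detach(), f_dpo(theta).item(), f_sft(theta).item())
```

Increasing k trades decision freshness for fewer full loss evaluations, which is the point of the memory-efficient variant.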