Training large language models (LLMs) with **policy gradient methods**, using techniques such as **KL penalty gradients** and **Proximal Policy Optimization (PPO)**, involves optimizing the model's policy to maximize a reward signal defined by a specific objective. Here's a breakdown:

---
### **1. Background: Policy Gradient in LLMs**
Policy gradient methods train a model by directly optimizing the policy (probability distribution over actions) using gradients of a reward signal. In the context of LLMs:

- **Policy**: The model's output probabilities over tokens or sequences (e.g., \( \pi_\theta(a|s) \), where \( a \) is an action or token, and \( s \) is the input context).
- **Objective**: Maximize expected rewards, such as alignment with user preferences, quality of generated text, or adherence to specific constraints.

The optimization problem is:
\[
\mathcal{L}(\theta) = \mathbb{E}_{\pi_\theta} \left[ R \right]
\]
where \( R \) is the reward.

Policy gradient methods adjust the parameters \( \theta \) in the direction of:
\[
\nabla_\theta \mathcal{L}(\theta) = \mathbb{E}_{\pi_\theta} \left[ \nabla_\theta \log \pi_\theta(a|s) \cdot R \right]
\]
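To make the update concrete, here is a minimal REINFORCE-style sketch in PyTorch. It assumes a Hugging Face-style causal LM whose forward pass exposes `.logits`, a `response_mask` that is 1 on generated tokens and 0 on prompt/padding positions, and one scalar reward per sampled sequence; the names are illustrative, not a particular library's API.

```python
import torch.nn.functional as F

def policy_gradient_loss(model, input_ids, response_mask, rewards):
    """REINFORCE-style sequence-level policy gradient loss (negated for minimization)."""
    # Logits that predict token t+1 from the prefix up to token t.
    logits = model(input_ids).logits[:, :-1, :]
    targets = input_ids[:, 1:]

    # Per-token log-probabilities of the tokens that were actually sampled.
    log_probs = F.log_softmax(logits, dim=-1)
    token_logp = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)

    # Sum over generated (response) tokens only: log pi_theta(a|s) per sequence.
    seq_logp = (token_logp * response_mask[:, 1:]).sum(dim=-1)

    # REINFORCE: ascend E[ grad log pi * R ]  ->  minimize -(log pi * R).
    return -(seq_logp * rewards).mean()
```

In practice, a baseline or advantage estimate usually replaces the raw reward \( R \) here to reduce the variance of the gradient.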
--- | ||
### **2. KL Penalty Gradients**
When fine-tuning an LLM, the new policy \( \pi_\theta \) should not deviate too much from the original policy \( \pi_{\text{ref}} \). The **Kullback-Leibler (KL) divergence** penalizes large deviations:
\[
\text{KL}(\pi_\theta || \pi_{\text{ref}}) = \sum_a \pi_\theta(a|s) \log \frac{\pi_\theta(a|s)}{\pi_{\text{ref}}(a|s)}
\]
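For reference, here is a small PyTorch sketch of this sum over the vocabulary, computed from the two models' logits (tensor names and shapes are assumptions for illustration). RLHF implementations often use a cheaper per-token estimate on the sampled tokens instead, shown below.

```python
import torch.nn.functional as F

def full_kl(policy_logits, ref_logits):
    """Exact KL(pi_theta || pi_ref) at each position, summed over the vocabulary.

    policy_logits, ref_logits: tensors of shape [batch, seq, vocab].
    """
    logp = F.log_softmax(policy_logits, dim=-1)       # log pi_theta(a|s)
    ref_logp = F.log_softmax(ref_logits, dim=-1)      # log pi_ref(a|s)
    # sum_a pi_theta(a|s) * (log pi_theta(a|s) - log pi_ref(a|s))
    return (logp.exp() * (logp - ref_logp)).sum(dim=-1)  # shape [batch, seq]
```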
The **KL penalty term** is added to the loss:
\[
\mathcal{L}(\theta) = \mathbb{E}_{\pi_\theta} \left[ R \right] - \beta \cdot \mathbb{E}_{\pi_\theta} \left[ \text{KL}(\pi_\theta || \pi_{\text{ref}}) \right]
\]

Here:
- \( \beta \): Controls the strength of the penalty.
- Intuition: Encourages the policy to remain close to the reference distribution unless there is a strong reward incentive to deviate.

In practice, the KL term is usually folded into the reward as a per-token estimate (often \( \log \pi_\theta(a|s) - \log \pi_{\text{ref}}(a|s) \) on the sampled tokens), so the gradient becomes:
\[
\nabla_\theta \mathcal{L}(\theta) = \mathbb{E}_{\pi_\theta} \left[ \nabla_\theta \log \pi_\theta(a|s) \cdot \left( R - \beta \cdot \text{KL}(\pi_\theta || \pi_{\text{ref}}) \right) \right]
\]
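The sketch below shows one common way this shaping is wired up: each generated token is penalized by \( \beta \) times a per-token KL estimate, and the sequence-level reward from a reward model is credited at the final generated token. Tensor names and shapes are illustrative assumptions, not a specific library's API.

```python
import torch

def kl_shaped_rewards(policy_logp, ref_logp, response_mask, seq_reward, beta=0.1):
    """Fold a per-token KL penalty into the reward signal (sketch).

    policy_logp, ref_logp, response_mask: [batch, resp_len], response tokens only,
    right-padded. seq_reward: [batch] scalar reward per sampled sequence.
    """
    # Per-token KL estimate on the sampled tokens: log pi_theta - log pi_ref.
    kl_est = (policy_logp - ref_logp) * response_mask

    # Every generated token pays the KL penalty ...
    token_rewards = -beta * kl_est

    # ... and the sequence-level reward is credited at the last generated token.
    last_idx = response_mask.long().sum(dim=-1).clamp(min=1) - 1
    token_rewards[torch.arange(token_rewards.size(0)), last_idx] += seq_reward
    return token_rewards  # [batch, resp_len], later turned into advantages
```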
--- | ||
### **3. Proximal Policy Optimization (PPO)**
**PPO** is a popular and robust policy gradient algorithm. It improves upon basic policy gradient methods by keeping updates **stable** and **limited**.

Key innovations:
- **Clipped Surrogate Objective**: PPO modifies the objective to prevent large updates to the policy (see the code sketch after this list):
  \[
  \mathcal{L}(\theta) = \mathbb{E}_t \left[ \min\left( r_t \cdot A_t, \; \text{clip}(r_t, 1-\epsilon, 1+\epsilon) \cdot A_t \right) \right]
  \]
  where:
  - \( r_t = \frac{\pi_\theta(a_t|s_t)}{\pi_{\text{old}}(a_t|s_t)} \): Ratio of new to old policy probabilities.
  - \( A_t \): Advantage estimate (how much better an action is than the baseline).
  - \( \epsilon \): Clip parameter (e.g., 0.2).

- **KL Penalty or Constraint**: Some versions of PPO explicitly incorporate a KL divergence term to control deviation:
  \[
  \mathcal{L}(\theta) = \mathbb{E}_{\pi_\theta} \left[ R \right] - \beta \cdot \text{KL}(\pi_\theta || \pi_{\text{ref}})
  \]

**Benefits**:
- Ensures policy updates stay within a "trust region," avoiding dramatic changes.
- Balances exploration of new policies against staying close to the reference policy.
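A minimal sketch of the clipped policy loss, assuming per-token log-probabilities of the sampled tokens under the current and old policies, per-token advantage estimates, and a mask over generated tokens (all names are illustrative):

```python
import torch

def ppo_clip_loss(logp_new, logp_old, advantages, response_mask, eps=0.2):
    """Clipped PPO surrogate (policy term only), averaged over response tokens."""
    # r_t = pi_theta(a_t|s_t) / pi_old(a_t|s_t), computed from log-probs.
    ratio = torch.exp(logp_new - logp_old)

    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - eps, 1.0 + eps) * advantages

    # Elementwise minimum, masked to generated tokens; negate because we minimize.
    per_token = torch.min(unclipped, clipped) * response_mask
    return -per_token.sum() / response_mask.sum()
```

Complete implementations typically add a value-function loss and often an entropy bonus to this policy term.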
--- | ||
### **4. Putting It All Together**
When training LLMs with PPO and KL penalty gradients:
1. **Initialize** with a reference model \( \pi_{\text{ref}} \) (e.g., a pre-trained LLM).
2. **Collect data** using the current policy \( \pi_\theta \) by sampling outputs and computing rewards (e.g., via human feedback or a reward model).
3. **Compute gradients** using:
   - The reward signal \( R \).
   - The KL divergence penalty to enforce similarity to \( \pi_{\text{ref}} \).
   - The clipped objective to stabilize updates.
4. **Update the policy** iteratively until convergence.
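To see how these steps fit together, here is a sketch of one iteration that reuses `kl_shaped_rewards` and `ppo_clip_loss` from the earlier sketches. The rollout, reward, advantage, and log-probability helpers are hypothetical callables supplied by the caller, not part of any specific framework.

```python
def rlhf_ppo_step(policy, optimizer, prompts, sample_fn, reward_fn,
                  logprob_fn, advantage_fn, beta=0.1, eps=0.2, ppo_epochs=4):
    """One RLHF-style PPO iteration (sketch); helper callables are injected."""
    # 1. Roll out: sample responses and record old-policy and reference log-probs.
    responses, response_mask, logp_old, logp_ref = sample_fn(policy, prompts)

    # 2. Score: sequence rewards -> KL-shaped per-token rewards -> advantages.
    seq_rewards = reward_fn(prompts, responses)
    token_rewards = kl_shaped_rewards(logp_old, logp_ref, response_mask,
                                      seq_rewards, beta)
    advantages = advantage_fn(token_rewards)  # e.g., GAE using a value head

    # 3. Optimize: a few epochs of clipped updates on the same rollout batch.
    for _ in range(ppo_epochs):
        logp_new = logprob_fn(policy, prompts, responses)  # differentiable
        loss = ppo_clip_loss(logp_new, logp_old, advantages, response_mask, eps)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```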
--- | ||
### **5. Practical Considerations**
- **Reward Design**: Rewards must reflect the desired behavior (e.g., correctness, coherence, user preference).
- **Hyperparameters** (illustrative values in the sketch after this list):
  - \( \beta \): KL penalty weight.
  - \( \epsilon \): PPO clip threshold.
- **Efficiency**: Large-scale LLM training requires distributed optimization and batching strategies.
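For concreteness, here is a small configuration bundle covering the knobs used in the sketches above; the defaults are common starting points, not values prescribed by any particular paper or library.

```python
from dataclasses import dataclass

@dataclass
class PPOConfig:
    beta: float = 0.1             # KL penalty weight
    eps: float = 0.2              # PPO clip threshold
    ppo_epochs: int = 4           # optimization epochs per rollout batch
    rollout_batch_size: int = 64  # prompts sampled per iteration
    learning_rate: float = 1e-6   # typical order of magnitude for LLM fine-tuning
```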
This approach has been used successfully in reinforcement learning from human feedback (RLHF) to fine-tune LLMs such as OpenAI's GPT models.