To train large language models (LLMs) using policy gradient methods, particularly with techniques like **KL penalty gradients** and **Proximal Policy Optimization (PPO)**, we use mathematical formulations to guide updates to the model's parameters. Let me break down each method with its formula and an explanation.

---

### 1. **Policy Gradient with KL Penalty**
The objective with a KL penalty balances exploration (trying new predictions) against staying close to the current policy (to prevent drastic changes). The loss function is:

\[
L(\theta) = \mathbb{E}_{x \sim \text{data},\, a \sim \pi_\theta} \left[ \log \pi_\theta(a|x)\, R(a, x) - \beta\, \text{KL}(\pi_\theta \| \pi_{\text{old}}) \right]
\]

#### Explanation of terms:
- \( \theta \): The parameters of the policy (e.g., the weights of the model).
- \( x \): The context or input (e.g., a sentence prefix).
- \( a \): The action (e.g., the token chosen by the model).
- \( \pi_\theta(a|x) \): The probability of taking action \( a \) under policy \( \pi_\theta \).
- \( R(a, x) \): The reward for taking action \( a \) in context \( x \) (e.g., based on human feedback or task-specific metrics).
- \( \beta \): A hyperparameter controlling the strength of the KL penalty.
- \( \text{KL}(\pi_\theta \| \pi_{\text{old}}) = \sum_a \pi_\theta(a|x) \log \frac{\pi_\theta(a|x)}{\pi_{\text{old}}(a|x)} \): The Kullback-Leibler divergence between the current policy \( \pi_\theta \) and the old policy \( \pi_{\text{old}} \).

#### Steps:
1. The term \( \log \pi_\theta(a|x)\, R(a, x) \) encourages actions that receive higher rewards.
2. The KL term \( \beta\, \text{KL}(\pi_\theta \| \pi_{\text{old}}) \) penalizes large deviations from the previous policy, so the model only moves away from it when there is a strong reward incentive to do so.
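To make the formula concrete, here is a minimal PyTorch sketch of this KL-penalized policy-gradient loss. It is a simplified illustration, not a reference implementation: the names (`policy_logits`, `old_logits`, `actions`, `rewards`, `beta`) are placeholders I am assuming, and the reward is a single scalar per sampled token rather than the output of a full reward model.

```python
import torch
import torch.nn.functional as F

def kl_penalized_pg_loss(policy_logits, old_logits, actions, rewards, beta=0.1):
    """KL-penalized policy-gradient loss for one batch of token-level actions.

    policy_logits, old_logits: (batch, vocab) logits from the current and old policies.
    actions: (batch,) sampled token ids; rewards: (batch,) scalar rewards.
    """
    log_probs = F.log_softmax(policy_logits, dim=-1)            # log pi_theta(.|x)
    old_log_probs = F.log_softmax(old_logits, dim=-1).detach()  # log pi_old(.|x), no gradient

    # log pi_theta(a|x) for the sampled actions
    action_log_probs = log_probs.gather(-1, actions.unsqueeze(-1)).squeeze(-1)

    # KL(pi_theta || pi_old) = sum_a pi_theta(a|x) * (log pi_theta(a|x) - log pi_old(a|x))
    kl = (log_probs.exp() * (log_probs - old_log_probs)).sum(dim=-1)

    # Maximize E[log pi * R - beta * KL]  ->  minimize the negative
    return -(action_log_probs * rewards - beta * kl).mean()

# Dummy usage: random tensors stand in for real model outputs and rewards.
batch, vocab = 4, 32
loss = kl_penalized_pg_loss(
    torch.randn(batch, vocab, requires_grad=True),
    torch.randn(batch, vocab),
    torch.randint(vocab, (batch,)),
    torch.randn(batch),
)
loss.backward()
```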
---

### 2. **Proximal Policy Optimization (PPO)**

PPO is designed to optimize the policy while constraining updates to prevent overly aggressive changes. The clipped objective is:

\[
L(\theta) = \mathbb{E}_{x, a} \left[ \min \left( r_\theta(a|x)\, \hat{A}(x, a),\ \text{clip}\big(r_\theta(a|x),\, 1 - \epsilon,\, 1 + \epsilon\big)\, \hat{A}(x, a) \right) \right]
\]

#### Explanation of terms:
- \( r_\theta(a|x) = \frac{\pi_\theta(a|x)}{\pi_{\text{old}}(a|x)} \): The ratio of the action's probability under the new and old policies.
- \( \hat{A}(x, a) \): The advantage estimate of how much better action \( a \) is than the average action at \( x \).
- \( \epsilon \): A small hyperparameter controlling the clipping range (e.g., 0.2).
- \( \text{clip}(r_\theta(a|x), 1 - \epsilon, 1 + \epsilon) \): Keeps the probability ratio within a narrow band, preventing overly large updates.

#### Steps:
1. \( r_\theta(a|x)\, \hat{A}(x, a) \): Encourages actions with a high advantage.
2. The clipping \( \text{clip}(r_\theta(a|x), 1 - \epsilon, 1 + \epsilon) \) keeps the policy update within a safe region to avoid destabilizing training.
3. Taking the minimum of the unclipped and clipped terms avoids overly optimistic updates.
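For reference, a minimal PyTorch sketch of this clipped objective might look like the following. The inputs (`new_log_probs`, `old_log_probs`, `advantages`) are assumed placeholders: per-action log-probabilities from the current and old policies, plus precomputed advantage estimates.

```python
import torch

def ppo_clip_loss(new_log_probs, old_log_probs, advantages, eps=0.2):
    # r_theta(a|x) = pi_theta(a|x) / pi_old(a|x), computed in log space for stability
    ratio = torch.exp(new_log_probs - old_log_probs.detach())

    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - eps, 1.0 + eps) * advantages

    # PPO maximizes the minimum of the two terms; return the negative for a minimizer.
    return -torch.min(unclipped, clipped).mean()

# Dummy usage with random values in place of real policy outputs.
new_lp = torch.randn(8, requires_grad=True)
old_lp = (new_lp + 0.05 * torch.randn(8)).detach()
advantages = torch.randn(8)
ppo_clip_loss(new_lp, old_lp, advantages).backward()
```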
---

### 3. **KL Penalty Gradient in PPO**

Some variations of PPO incorporate an explicit KL penalty instead of clipping. The objective in such cases is:

\[
L(\theta) = \mathbb{E}_{x, a} \left[ r_\theta(a|x)\, \hat{A}(x, a) - \beta\, \text{KL}(\pi_\theta \| \pi_{\text{old}}) \right]
\]

#### Explanation:
This combines the clipping-free PPO update with a KL penalty term to keep policy updates controlled.
- The **advantage term** encourages exploration based on observed rewards.
- The **KL term** penalizes divergence from the previous policy, just as in the KL penalty method of Section 1.
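As with the clipped variant, here is a minimal sketch of this KL-penalized objective. It assumes access to the full token distributions (`new_logits`, `old_logits`) so the KL divergence can be computed exactly; all names and the value of `beta` are illustrative.

```python
import torch
import torch.nn.functional as F

def ppo_kl_penalty_loss(new_logits, old_logits, actions, advantages, beta=0.05):
    new_log_probs = F.log_softmax(new_logits, dim=-1)
    old_log_probs = F.log_softmax(old_logits, dim=-1).detach()

    # Probability ratio r_theta(a|x) for the sampled actions
    new_lp_a = new_log_probs.gather(-1, actions.unsqueeze(-1)).squeeze(-1)
    old_lp_a = old_log_probs.gather(-1, actions.unsqueeze(-1)).squeeze(-1)
    ratio = torch.exp(new_lp_a - old_lp_a)

    # KL(pi_theta || pi_old) over the full token distribution
    kl = (new_log_probs.exp() * (new_log_probs - old_log_probs)).sum(dim=-1)

    # Maximize ratio * advantage - beta * KL  ->  minimize the negative
    return -(ratio * advantages - beta * kl).mean()
```

In some implementations, \( \beta \) is also adapted during training so that the measured KL stays near a target value.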
---

### Summary of Differences:
- **KL Penalty**: Explicitly penalizes divergence from the previous policy, weighted by a hyperparameter \( \beta \).
- **PPO**: Uses clipping to constrain updates instead of an explicit KL penalty, ensuring stability while allowing flexibility.
- **KL in PPO**: A hybrid approach that combines advantage-driven updates with a KL penalty for more explicit control.

Each approach has trade-offs and can be tuned based on model size, task complexity, and stability requirements. These methods have been used in reinforcement learning from human feedback (RLHF) to fine-tune LLMs such as OpenAI's GPT models. Let me know if you want a deeper dive into any of these!