To train large language models (LLMs) with **policy gradient methods**, particularly techniques like **KL penalty gradients** and **Proximal Policy Optimization (PPO)**, we optimize the model's policy to maximize rewards from a specific objective function. Below, each method is broken down with its formula and an explanation of the terms.

---

### 1. **Policy Gradient with KL Penalty**

Policy gradient methods train a model by directly optimizing its policy (the probability distribution over actions) using gradients of a reward signal. In the context of LLMs:

- **Policy**: The model's output probabilities over tokens or sequences (e.g., \( \pi_\theta(a|x) \), where \( a \) is an action or token and \( x \) is the input context).
- **Objective**: Maximize expected reward, such as alignment with user preferences, quality of generated text, or adherence to specific constraints.

The basic objective is \( \mathcal{L}(\theta) = \mathbb{E}_{\pi_\theta}[R] \), where \( R \) is the reward, and policy gradient methods adjust the parameters \( \theta \) in the direction of \( \nabla_\theta \mathcal{L}(\theta) = \mathbb{E}_{\pi_\theta}[\nabla_\theta \log \pi_\theta(a|x) \cdot R] \).

Adding a KL penalty balances exploration (new predictions) against staying close to the previous policy (preventing drastic changes). The loss function becomes:

\[
L(\theta) = \mathbb{E}_{x \sim \text{data},\, a \sim \pi_\theta} \left[ \log \pi_\theta(a|x) R(a, x) - \beta \text{KL}(\pi_\theta || \pi_{\text{old}}) \right]
\]
#### Explanation of terms:
- \( \theta \): The parameters of the policy (e.g., weights of the model).
- \( x \): Context or input (e.g., a sentence prefix).
- \( a \): Action (e.g., the token chosen by the model).
- \( \pi_\theta(a|x) \): The probability of taking action \( a \) under policy \( \pi_\theta \).
- \( R(a, x) \): The reward for taking action \( a \) in context \( x \) (e.g., based on human feedback or task-specific metrics).
- \( \beta \): A hyperparameter controlling the strength of the KL penalty.
- \( \text{KL}(\pi_\theta || \pi_{\text{old}}) \): The Kullback-Leibler divergence between the current policy \( \pi_\theta \) and the old (reference) policy \( \pi_{\text{old}} \), which penalizes large deviations:

\[
\text{KL}(\pi_\theta || \pi_{\text{old}}) = \sum_a \pi_\theta(a|x) \log \frac{\pi_\theta(a|x)}{\pi_{\text{old}}(a|x)}
\]

#### Steps:
1. The term \( \log \pi_\theta(a|x) R(a, x) \) encourages actions with higher rewards.
2. The KL term \( \beta \text{KL}(\pi_\theta || \pi_{\text{old}}) \) penalizes large deviations from the previous policy: the model stays close to the reference distribution unless there is a strong reward incentive to deviate.

The corresponding gradient is:
\[
\nabla_\theta L(\theta) = \mathbb{E}_{x,\, a \sim \pi_\theta} \left[ \nabla_\theta \log \pi_\theta(a|x) \cdot \left( R(a, x) - \beta \text{KL}(\pi_\theta || \pi_{\text{old}}) \right) \right]
\]
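
To make this concrete, here is a minimal PyTorch-style sketch of the KL-penalized policy-gradient loss for one batch. It is illustrative only: `logprobs`, `old_logprobs`, and `rewards` are assumed to be precomputed per-sample tensors, and the KL term is approximated per sample rather than summed over the full vocabulary.

```python
import torch

def kl_penalized_pg_loss(logprobs, old_logprobs, rewards, beta=0.1):
    """Sketch of L(theta) = E[log pi(a|x) * R(a, x) - beta * KL(pi || pi_old)].

    logprobs:     log pi_theta(a|x) for the sampled actions, shape (batch,)
    old_logprobs: log pi_old(a|x) for the same actions, shape (batch,)
    rewards:      R(a, x) per sample, shape (batch,)
    beta:         KL penalty weight
    """
    # Per-sample KL estimate: log pi_theta(a|x) - log pi_old(a|x)
    # (a simple Monte Carlo approximation of the full KL divergence).
    kl = logprobs - old_logprobs

    # Policy-gradient term: weight each log-probability by its reward.
    pg_term = logprobs * rewards

    # We maximize the objective, so the training loss is its negative.
    return -(pg_term - beta * kl).mean()

# Usage with dummy tensors:
logprobs = torch.randn(8, requires_grad=True)
old_logprobs = torch.randn(8)
rewards = torch.randn(8)
print(kl_penalized_pg_loss(logprobs, old_logprobs, rewards))
```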

---

### 2. **Proximal Policy Optimization (PPO)**

PPO is a popular and robust policy gradient algorithm. It is designed to optimize policies while constraining updates to prevent overly aggressive changes. The clipped objective is:

\[
L(\theta) = \mathbb{E}_{x, a} \left[ \min \left( r_\theta(a|x) \hat{A}(x, a), \text{clip}(r_\theta(a|x), 1 - \epsilon, 1 + \epsilon) \hat{A}(x, a) \right) \right]
\]
#### Explanation of terms:
- \( r_\theta(a|x) = \frac{\pi_\theta(a|x)}{\pi_{\text{old}}(a|x)} \): The ratio of probabilities under the new and old policies.
- \( \hat{A}(x, a) \): The advantage function estimating how much better action \( a \) is compared to the average action at \( x \).
- \( \epsilon \): A small hyperparameter controlling the clipping range (e.g., \( \epsilon = 0.2 \)).
- \( \text{clip}(r_\theta(a|x), 1 - \epsilon, 1 + \epsilon) \): Ensures that the probability ratio remains within a specified range, preventing overly large updates.

#### Steps:
1. \( r_\theta(a|x) \hat{A}(x, a) \): Encourages actions that have a high advantage.
2. The clipping \( \text{clip}(r_\theta(a|x), 1 - \epsilon, 1 + \epsilon) \) keeps the policy update within a safe region to avoid destabilization.
3. Taking the minimum of the unclipped and clipped terms avoids overly optimistic updates, so the policy stays within a "trust region" around \( \pi_{\text{old}} \).
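
Here is a minimal sketch of the clipped surrogate loss, again assuming per-sample log-probabilities and advantage estimates are already available (the function and argument names are illustrative, not from a specific library):

```python
import torch

def ppo_clip_loss(logprobs, old_logprobs, advantages, epsilon=0.2):
    """Sketch of the PPO clipped objective, negated so it can be minimized."""
    # Probability ratio r_theta(a|x) = pi_theta(a|x) / pi_old(a|x).
    ratio = torch.exp(logprobs - old_logprobs)

    # Unclipped and clipped surrogate terms.
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - epsilon, 1.0 + epsilon) * advantages

    # Element-wise minimum of the two, averaged over the batch.
    return -torch.min(unclipped, clipped).mean()
```

In practice, `old_logprobs` and `advantages` are treated as constants (detached from the graph) so that gradients flow only through `logprobs`.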

---

### 3. **KL Penalty Gradient in PPO**

Some variations of PPO incorporate an explicit KL penalty instead of clipping. The objective in such cases is:

\[
L(\theta) = \mathbb{E}_{x, a} \left[ r_\theta(a|x) \hat{A}(x, a) - \beta \text{KL}(\pi_\theta || \pi_{\text{old}}) \right]
\]

#### Explanation:
This combines the clipping-free PPO update with the KL penalty term to ensure controlled policy updates (a minimal sketch follows the bullets below).

- The **advantage term** \( r_\theta(a|x) \hat{A}(x, a) \) encourages exploration based on observed rewards.
- The **KL term** penalizes divergence from the previous policy, as in the earlier KL penalty method.
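
A minimal sketch of this variant, with the same illustrative tensor conventions as above (some implementations also adapt \( \beta \) during training, raising it when the measured KL exceeds a target and lowering it otherwise):

```python
import torch

def ppo_kl_penalty_loss(logprobs, old_logprobs, advantages, beta=0.05):
    """Sketch of L(theta) = E[r_theta * A_hat - beta * KL(pi || pi_old)], negated for minimization."""
    # Probability ratio under the new vs. old policy.
    ratio = torch.exp(logprobs - old_logprobs)

    # Per-sample Monte Carlo estimate of KL(pi_theta || pi_old).
    kl = logprobs - old_logprobs

    # Advantage-weighted ratio minus the KL penalty, averaged and negated.
    return -(ratio * advantages - beta * kl).mean()
```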

---

### 4. **Putting It All Together**
When training LLMs with PPO and KL penalty gradients (a toy version of this loop is sketched after the list):
1. **Initialize** with a reference model \( \pi_{\text{old}} \) (e.g., a pre-trained LLM).
2. **Collect data** using the current policy \( \pi_\theta \) by sampling outputs and computing rewards (e.g., via human feedback or a reward model).
3. **Compute gradients** using:
   - the reward signal \( R \),
   - the KL divergence penalty to enforce similarity to \( \pi_{\text{old}} \), and
   - the clipped objective to stabilize updates.
4. **Update the policy** iteratively until convergence.
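
The loop below is a self-contained toy version of these steps: a one-step "policy" over a tiny vocabulary, a dummy reward, and the `ppo_clip_loss` sketch from above. It only mirrors the structure of the loop, not real LLM training.

```python
import torch
import torch.nn as nn

VOCAB, CTX = 8, 16
policy = nn.Linear(CTX, VOCAB)                 # pi_theta: logits over tokens
old_policy = nn.Linear(CTX, VOCAB)             # frozen copy used for ratios / KL
old_policy.load_state_dict(policy.state_dict())
optimizer = torch.optim.Adam(policy.parameters(), lr=1e-2)
reward_of = lambda tokens: (tokens == 3).float()   # toy reward: prefer token 3

for step in range(100):
    # 1. Collect data with the current policy.
    contexts = torch.randn(32, CTX)
    dist = torch.distributions.Categorical(logits=policy(contexts))
    actions = dist.sample()
    logprobs = dist.log_prob(actions)

    with torch.no_grad():
        old_dist = torch.distributions.Categorical(logits=old_policy(contexts))
        old_logprobs = old_dist.log_prob(actions)
        # 2. Rewards and a simple baseline-subtracted advantage.
        rewards = reward_of(actions)
        advantages = rewards - rewards.mean()

    # 3. Clipped PPO update (the KL-penalized loss above works here too).
    loss = ppo_clip_loss(logprobs, old_logprobs, advantages, epsilon=0.2)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # 4. Periodically refresh the old policy.
    if step % 10 == 0:
        old_policy.load_state_dict(policy.state_dict())
```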

---

### 5. **Practical Considerations**
- **Reward Design**: Rewards must reflect the desired behavior (e.g., correctness, coherence, user preference).
- **Hyperparameters** (gathered into a small config sketch after this list):
  - \( \beta \): KL penalty weight.
  - \( \epsilon \): PPO clip threshold.
- **Efficiency**: Large-scale LLM training requires distributed optimization and batching strategies.
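
These knobs are often collected into a single configuration object. The example below is purely illustrative; the names and default values are hypothetical, not from any particular library:

```python
from dataclasses import dataclass

@dataclass
class PPOConfig:
    """Hypothetical hyperparameter bundle for the methods above (illustrative defaults)."""
    beta: float = 0.05         # KL penalty weight
    epsilon: float = 0.2       # PPO clip threshold
    learning_rate: float = 1e-5
    batch_size: int = 32       # sampled sequences per update
    ppo_epochs: int = 4        # optimization passes over each batch

config = PPOConfig()
print(config)
```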

---

### Summary of Differences:
- **KL Penalty**: Explicitly penalizes divergence from the previous policy with a hyperparameter \( \beta \).
- **PPO**: Uses clipping to constrain updates instead of an explicit KL penalty. It ensures stability while allowing flexibility.
- **KL in PPO**: A hybrid approach that combines advantage-driven updates with a KL penalty for more explicit control.

These approaches have been used successfully in reinforcement learning from human feedback (RLHF) to fine-tune LLMs such as OpenAI's GPT models. Each has trade-offs and can be tuned to the model size, task complexity, and stability requirements. Let me know if you want a deeper dive into any of these!
