Training large language models (LLMs) using **policy gradient methods** with techniques like **KL penalty gradients** and **Proximal Policy Optimization (PPO)** involves optimizing the model's policy to maximize an expected reward while keeping the updated policy close to a reference model. Here's a breakdown:

---

### **1. Background: Policy Gradient in LLMs**
Policy gradient methods train a model by directly optimizing the policy (probability distribution over actions) using gradients of a reward signal. In the context of LLMs:

- **Policy**: The model's output probabilities over tokens or sequences (e.g., \( \pi_\theta(a|s) \), where \( a \) is an action or token, and \( s \) is the input context).
- **Objective**: Maximize expected rewards, such as alignment with user preferences, quality of generated text, or adherence to specific constraints.

The objective to maximize is:
\[
\mathcal{L}(\theta) = \mathbb{E}_{\pi_\theta} \left[ R \right]
\]
where \( R \) is the reward.

Policy gradient methods adjust the parameters \( \theta \) in the direction of:
\[
\nabla_\theta \mathcal{L}(\theta) = \mathbb{E}_{\pi_\theta} \left[ \nabla_\theta \log \pi_\theta(a|s) \cdot R \right]
\]
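
Below is a minimal PyTorch sketch of this score-function (REINFORCE-style) gradient, written as a loss to minimize; the tensor names `logprobs` and `rewards` are illustrative placeholders, not from any specific library:

```python
import torch

def policy_gradient_loss(logprobs: torch.Tensor, rewards: torch.Tensor) -> torch.Tensor:
    """REINFORCE-style surrogate loss.

    logprobs: log pi_theta(a|s) for the sampled tokens, shape (batch,)
    rewards:  scalar reward R for each sample, shape (batch,)
    """
    # Rewards are treated as constants: gradients flow only through logprobs,
    # so minimizing this loss ascends E[ grad log pi_theta(a|s) * R ].
    return -(logprobs * rewards.detach()).mean()
```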

---

### **2. KL Penalty Gradients**
When fine-tuning an LLM, the new policy \( \pi_\theta \) should not deviate too much from the original policy \( \pi_{\text{ref}} \). The **Kullback-Leibler (KL) divergence** penalizes large deviations:
\[
\text{KL}(\pi_\theta || \pi_{\text{ref}}) = \sum_a \pi_\theta(a|s) \log \frac{\pi_\theta(a|s)}{\pi_{\text{ref}}(a|s)}
\]

The **KL penalty term** is subtracted from the objective:
\[
\mathcal{L}(\theta) = \mathbb{E}_{\pi_\theta} \left[ R \right] - \beta \cdot \mathbb{E}_{\pi_\theta} \left[ \text{KL}(\pi_\theta || \pi_{\text{ref}}) \right]
\]

Here:
- \( \beta \): Controls the strength of the penalty.
- Intuition: Encourages the policy to remain close to the reference distribution unless there's a strong reward incentive to deviate.

Treating the KL term as part of a shaped reward, the policy gradient estimate becomes:
\[
\nabla_\theta \mathcal{L}(\theta) \approx \mathbb{E}_{\pi_\theta} \left[ \nabla_\theta \log \pi_\theta(a|s) \cdot \left( R - \beta \cdot \text{KL}(\pi_\theta || \pi_{\text{ref}}) \right) \right]
\]
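
In RLHF-style implementations this penalty is typically applied per token, using \( \log \pi_\theta(a|s) - \log \pi_{\text{ref}}(a|s) \) as a simple estimate of the KL term. A minimal sketch of this shaped reward, assuming `rewards`, `logprobs`, and `ref_logprobs` are per-token tensors (the names are illustrative):

```python
import torch

def kl_penalized_rewards(rewards: torch.Tensor,
                         logprobs: torch.Tensor,
                         ref_logprobs: torch.Tensor,
                         beta: float = 0.1) -> torch.Tensor:
    """Shape the reward with a per-token KL estimate.

    rewards:      reward per token position, shape (batch, seq)
    logprobs:     log pi_theta(a|s) for the sampled tokens, shape (batch, seq)
    ref_logprobs: log pi_ref(a|s) for the same tokens, shape (batch, seq)
    beta:         KL penalty weight.
    """
    # Per-token estimate of KL(pi_theta || pi_ref) at the sampled actions.
    kl_estimate = logprobs - ref_logprobs
    return rewards - beta * kl_estimate
```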

---

### **3. Proximal Policy Optimization (PPO)**
**PPO** is a popular and robust policy gradient algorithm. It improves upon basic policy gradient methods by ensuring updates are **stable** and **limited**.

Key innovations:
- **Clipped Surrogate Objective**: PPO modifies the objective to prevent large updates to the policy (see the sketch at the end of this section):
\[
\mathcal{L}(\theta) = \mathbb{E}_t \left[ \min\left(r_t \cdot A_t, \ \text{clip}(r_t, 1-\epsilon, 1+\epsilon) \cdot A_t\right) \right]
\]
where:
  - \( r_t = \frac{\pi_\theta(a_t|s_t)}{\pi_{\text{old}}(a_t|s_t)} \): Ratio of new to old policy probabilities at step \( t \).
  - \( A_t \): Advantage estimate (how much better an action is compared to the baseline).
  - \( \epsilon \): Clip parameter (e.g., 0.2).

- **KL Penalty or Constraint**: Some versions of PPO explicitly incorporate a KL divergence term to control deviation:
\[
\mathcal{L}(\theta) = \mathbb{E}_{\pi_\theta} \left[ R \right] - \beta \cdot \text{KL}(\pi_\theta || \pi_{\text{ref}})
\]

**Benefits**:
- Ensures policy updates stay within a "trust region," avoiding dramatic changes.
- Balances exploration of new policies and staying close to the reference policy.
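
A minimal PyTorch sketch of the clipped surrogate objective, written as a loss to minimize; `logprobs`, `old_logprobs`, and `advantages` are assumed to be precomputed tensors for the sampled tokens (names are illustrative):

```python
import torch

def ppo_clipped_loss(logprobs: torch.Tensor,
                     old_logprobs: torch.Tensor,
                     advantages: torch.Tensor,
                     clip_eps: float = 0.2) -> torch.Tensor:
    """Clipped surrogate objective (negated so it can be minimized).

    logprobs:     log pi_theta(a_t|s_t) under the current policy, shape (batch,)
    old_logprobs: log pi_old(a_t|s_t) recorded at rollout time, shape (batch,)
    advantages:   advantage estimates A_t, shape (batch,)
    clip_eps:     clip parameter epsilon (e.g., 0.2).
    """
    # Probability ratio r_t = pi_theta / pi_old, computed in log space.
    ratio = torch.exp(logprobs - old_logprobs.detach())
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    # Take the pessimistic (minimum) surrogate and negate it to get a loss.
    return -torch.min(unclipped, clipped).mean()
```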

---

### **4. Putting It All Together**
When training LLMs with PPO and KL penalty gradients:
1. **Initialize** with a reference model \( \pi_{\text{ref}} \) (e.g., a pre-trained LLM).
2. **Collect data** using the current policy \( \pi_\theta \) by sampling outputs and computing rewards (e.g., via human feedback or a reward model).
3. **Compute gradients** using:
- Reward signal \( R \).
- KL divergence penalty to enforce similarity to \( \pi_{\text{ref}} \).
- Clipped objective to stabilize updates.
4. **Update policy** iteratively until convergence.
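
Putting these steps together, one simplified PPO update might look like the sketch below. All inputs (`policy_logprobs`, `old_logprobs`, `ref_logprobs`, `rewards`) are assumed to come from a rollout and a reward model, and the advantage estimate here is deliberately crude (a mean baseline instead of a learned value function):

```python
import torch

def rlhf_ppo_step(policy_logprobs: torch.Tensor,
                  old_logprobs: torch.Tensor,
                  ref_logprobs: torch.Tensor,
                  rewards: torch.Tensor,
                  optimizer: torch.optim.Optimizer,
                  beta: float = 0.1,
                  clip_eps: float = 0.2) -> float:
    """One simplified PPO update on a batch of sampled responses.

    policy_logprobs: log pi_theta(a|s), recomputed by the current policy (requires grad).
    old_logprobs:    log pi_old(a|s), recorded at rollout time.
    ref_logprobs:    log pi_ref(a|s) from the frozen reference model.
    rewards:         rewards from a reward model or human feedback.
    """
    # 1. Shape rewards with the per-token KL penalty against the reference model.
    shaped_rewards = rewards - beta * (old_logprobs - ref_logprobs)

    # 2. Crude advantage estimate: subtract the batch mean as a baseline
    #    (a learned value function / GAE would normally be used here).
    advantages = (shaped_rewards - shaped_rewards.mean()).detach()

    # 3. Clipped surrogate loss and a gradient step.
    ratio = torch.exp(policy_logprobs - old_logprobs.detach())
    clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps)
    loss = -torch.min(ratio * advantages, clipped * advantages).mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```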

---

### **5. Practical Considerations**
- **Reward Design**: Rewards must reflect the desired behavior (e.g., correctness, coherence, user preference).
- **Hyperparameters**:
- \( \beta \): KL penalty weight.
- \( \epsilon \): PPO clip threshold.
- **Efficiency**: Large-scale LLM training requires distributed optimization and batching strategies.
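
As a rough illustration, the key knobs above can be collected into a small configuration object; the values below are common orders of magnitude, not recommendations from this document:

```python
from dataclasses import dataclass

@dataclass
class PPOTrainingConfig:
    beta: float = 0.1         # KL penalty weight
    clip_eps: float = 0.2     # PPO clip threshold
    learning_rate: float = 1e-5
    batch_size: int = 64
    ppo_epochs: int = 4       # optimization passes per batch of rollouts
```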

This approach has been successfully used in reinforcement learning with human feedback (RLHF) for fine-tuning LLMs like OpenAI's GPT models.
