[trainer,hparams,docs] feat: add CRD (Centered Reward Distillation) algorithm #121
yuanzhi-zhu wants to merge 4 commits into X-GenGroup:main from …
Conversation
feat: add CRD (Centered Reward Distillation) algorithm

Implements Centered Reward Distillation (arXiv:2603.14128) as a new decoupled RL trainer for flow-matching models.

Key changes:
- `trainers/crd.py`: Full CRDTrainer implementation with old/sampling model snapshots, dual-direction centering loss, adaptive KL, and per-step velocity-space implicit reward estimation
- `hparams/training_args.py`: CRDTrainingArguments with paper-aligned defaults (decay schedules, kl_beta=0.1, kl_cfg=4.5)
- `trainers/registry.py`: Register 'crd' key
- `hparams/__init__.py`: Export CRDTrainingArguments
- `examples/crd/lora/sd3_5.yaml`: SD3.5 + OCR example config matching paper Table 3 hyperparameters (K=24, 2 grad steps, timestep_range=0.99)
- `guidance/algorithms.md`: CRD section with hyperparameter reference and centering modes table
- `.agents/knowledge/architecture.md`: Add CRD to trainer registry table

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Pull request overview
Adds a new decoupled RL trainer implementing Centered Reward Distillation (CRD) for flow-matching models, along with corresponding hyperparameters, registry wiring, and user/internal documentation.
Changes:
- Introduces `CRDTrainer` with old/sampling parameter snapshots, centered reward-matching loss, and reference-model KL regularization.
- Adds `CRDTrainingArguments` and registers the new `'crd'` trainer + hparams key.
- Documents CRD usage/hyperparameters and provides an SD3.5 + OCR example config.
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| src/flow_factory/trainers/registry.py | Registers 'crd' → CRDTrainer for dynamic trainer loading. |
| src/flow_factory/trainers/crd.py | New CRD trainer implementation (sampling/optimization/loss/KL + snapshot decay). |
| src/flow_factory/hparams/training_args.py | Adds CRDTrainingArguments and registers it under 'crd'. |
| src/flow_factory/hparams/__init__.py | Exposes CRDTrainingArguments from the hparams package. |
| guidance/algorithms.md | Adds CRD algorithm documentation and hyperparameter reference. |
| examples/crd/lora/sd3_5.yaml | Adds a paper-aligned CRD LoRA config example for SD3.5 + OCR reward. |
| .agents/knowledge/architecture.md | Updates internal architecture docs to include CRD in the trainer registry table. |
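
The registry wiring above can be smoke-tested directly. A minimal sketch, assuming the import paths follow the `src/flow_factory` layout in the table and the `get_training_args_class` / `get_trainer_class` accessors named in the test plan at the bottom of this PR:

```python
# Hedged smoke test for the 'crd' registry wiring; module paths are
# assumptions inferred from the file table, not confirmed by the PR.
from flow_factory.hparams import CRDTrainingArguments
from flow_factory.hparams.training_args import get_training_args_class
from flow_factory.trainers.registry import get_trainer_class

def test_crd_registry_wiring():
    assert get_training_args_class('crd') is CRDTrainingArguments
    assert get_trainer_class('crd').__name__ == 'CRDTrainer'
```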
```python
# Gather r_theta across all GPUs for centering
r_theta_gathered = self.accelerator.gather(r_theta_local.detach()).to(
    self.accelerator.device
)

# 5. Compute advantages for CRD centering
adv = batch['advantage']
adv_clip_range = self.training_args.adv_clip_range
adv_clipped = torch.clamp(adv, adv_clip_range[0], adv_clip_range[1])

# Normalize to [0, 1]
normalized_adv = (adv_clipped / max(adv_clip_range)) / 2.0 + 0.5
adv_cur_rank = torch.clamp(normalized_adv, 0, 1)

# Gather advantages across all GPUs
adv_cur = self.accelerator.gather(adv_cur_rank.detach()).to(
    self.accelerator.device
)

# 6. Centered Reward Distillation loss (supports dual-direction centering)
ori_policy_loss = self._compute_crd_loss(
    adv_cur=adv_cur,
    adv_cur_rank=adv_cur_rank,
    r_theta_gathered=r_theta_gathered,
    r_theta_local=r_theta_local,
)
```
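
To make the normalization concrete: with a symmetric clip range, the mapping sends -max to 0, 0 to 0.5, and +max to 1. A standalone check, assuming a hypothetical `adv_clip_range = (-5.0, 5.0)` (the config's actual default may differ):

```python
import torch

adv = torch.tensor([-12.0, -5.0, 0.0, 2.5, 9.0])
adv_clip_range = (-5.0, 5.0)  # hypothetical value for illustration

adv_clipped = torch.clamp(adv, adv_clip_range[0], adv_clip_range[1])
normalized_adv = (adv_clipped / max(adv_clip_range)) / 2.0 + 0.5
adv_cur_rank = torch.clamp(normalized_adv, 0, 1)

print(adv_cur_rank)  # tensor([0.0000, 0.0000, 0.5000, 0.7500, 1.0000])
```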
```python
def _blend_named_params(self, name: str, decay: float):
    """
    Blend a named parameter snapshot towards the current trainable parameters.

    Formula: ``snapshot = decay * snapshot + (1 - decay) * current``

    Args:
        name: Name of the parameter snapshot.
        decay: Blending coefficient. 0.0 = full copy, 1.0 = no change.
    """
    if decay <= 0.0:
        # Full copy from current params (no blending)
        self.adapter.update_named_parameters(name)
    elif decay >= 1.0:
        # Keep snapshot unchanged (fully offline)
        pass
    else:
        # Exponential blending: snapshot = decay * snapshot + (1 - decay) * current
        info = self.adapter._named_parameters[name]
        current_params = self.adapter._get_component_parameters(info.target_components)
        with torch.no_grad():
            for ema_param, param in zip(info.ema_wrapper.ema_parameters, current_params, strict=True):
                ema_param.data.mul_(decay).add_(
                    param.detach().to(ema_param.device), alpha=(1.0 - decay)
                )
```
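
To make the decay semantics concrete, here is the same blend rule applied to bare tensors (a minimal sketch, not the adapter/EMA machinery): `decay=0.0` overwrites the snapshot with the current weights, `decay=1.0` freezes it, and intermediate values give an exponential moving average.

```python
import torch

snapshot = torch.zeros(3)   # stands in for one EMA snapshot parameter
current = torch.ones(3)     # stands in for the matching trainable parameter

decay = 0.9  # hypothetical; the PR schedules decay over training steps
with torch.no_grad():
    # snapshot = decay * snapshot + (1 - decay) * current
    snapshot.mul_(decay).add_(current, alpha=1.0 - decay)

print(snapshot)  # tensor([0.1000, 0.1000, 0.1000])
```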
```python
if self.reward_adaptive_kl:
    # Linearly scale KL based on reward value
    raw_reward = adv_cur_rank  # Already in [0, 1]
    base_beta = 1e-4
    min_coef = base_beta / max(self.kl_beta, 1e-8)
    kl_loss = self.kl_beta * torch.mean((min_coef + raw_reward * (1 - min_coef)) * kl_div)
else:
```
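
The per-sample coefficient applied to `kl_div` is `kl_beta * (min_coef + raw_reward * (1 - min_coef))`, which interpolates linearly between `base_beta` at reward 0 and `kl_beta` at reward 1. A quick endpoint check with the PR's defaults:

```python
kl_beta = 0.1     # PR default
base_beta = 1e-4
min_coef = base_beta / max(kl_beta, 1e-8)

for raw_reward in (0.0, 0.5, 1.0):
    coef = kl_beta * (min_coef + raw_reward * (1 - min_coef))
    print(raw_reward, coef)
# 0.0 -> 1e-4 (= base_beta), 1.0 -> 0.1 (= kl_beta), linear in between
```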
### Centering Modes (`weight_temp`)

| `weight_temp` | Mode | Description |
|---|---|---|
| `< 0` | Uniform (τ→∞) | Simple mean centering; recommended default |
| `== 0` | Hard selection | Positive pool (adv > 0) vs negative pool (adv < 0) |
| `> 0` | Softmax temperature | Dual-direction: `softmax(adv/τ)` and `softmax(-adv/τ)` |
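
As an illustration of the three modes, here is a hypothetical sketch of how advantages could map to centering weights; the actual `_compute_crd_loss` in the PR may differ in detail.

```python
import torch

def centering_weights(adv: torch.Tensor, weight_temp: float):
    """Hypothetical sketch of the weight_temp modes in the table above."""
    if weight_temp < 0:
        # Uniform (tau -> inf): plain mean centering
        w_pos = w_neg = torch.full_like(adv, 1.0 / adv.numel())
    elif weight_temp == 0:
        # Hard selection: uniform over the positive / negative pools
        pos, neg = (adv > 0).float(), (adv < 0).float()
        w_pos = pos / pos.sum().clamp_min(1.0)
        w_neg = neg / neg.sum().clamp_min(1.0)
    else:
        # Softmax temperature: dual-direction weights
        w_pos = torch.softmax(adv / weight_temp, dim=0)
        w_neg = torch.softmax(-adv / weight_temp, dim=0)
    return w_pos, w_neg

w_pos, w_neg = centering_weights(torch.tensor([-1.0, 0.5, 2.0]), 0.5)
```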
If you have a chance, it would be great if you could paste training curves that show the performance increase from your algorithm, so that we can verify the correctness of the implementation.
@yuanzhi-zhu

@Jayce-Ping @MengHao666

The example config sets …
Resolve conflicts between CRD trainer (this PR) and DGPO trainer (X-GenGroup#133) added to main since branching:

- training_args.py: Place DGPOTrainingArguments before CRDTrainingArguments; each class retains its own fields and defaults
- algorithms.md: Adopt main's reference numbering, add CRD as ref [13]
- crd.py: Add _maybe_offload_samples_to_cpu() call matching the pattern established in GRPO/GRPOGuard on main (PR X-GenGroup#149)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1. Fix: add a `get_preprocess_guidance_scale()` override to `CRDTrainingArguments`; it returns `max(guidance_scale, kl_cfg)` so that negative prompts are encoded during preprocessing even when `guidance_scale=1.0` (sampling without CFG) but `kl_cfg=4.5` (the reference-model KL uses CFG). Also add `kl_type` validation in `__post_init__` to match the NFT/AWM/GRPO pattern.
2. Doc: add a note in `CRDTrainer.optimize()` explaining that its two-pass batching strategy differs from the per-batch lazy-reload pattern adopted by GRPO/NFT/AWM/DPO in 3e615ee (X-GenGroup#118). The optimize logic is intentionally left unchanged to avoid potential correctness issues; it may be refactored in the future.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
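A minimal sketch of the override described in point 1, using a simplified stand-in for the real `CRDTrainingArguments` (the actual class has many more fields):

```python
from dataclasses import dataclass

@dataclass
class CRDTrainingArguments:  # simplified stand-in for illustration
    guidance_scale: float = 1.0
    kl_cfg: float = 4.5

    def get_preprocess_guidance_scale(self) -> float:
        # Encode negative prompts during preprocessing even when sampling
        # runs without CFG (guidance_scale=1.0), because the reference-model
        # KL still uses CFG (kl_cfg=4.5).
        return max(self.guidance_scale, self.kl_cfg)

assert CRDTrainingArguments().get_preprocess_guidance_scale() == 4.5
```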
Rename examples/crd/lora/sd3_5.yaml → examples/crd/lora/sd3_5/default.yaml
to match the convention: examples/{algorithm}/{finetune_type}/{model_type}/{variant}.yaml
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
I have refactored this PR to solve the issues mentioned above. If everything looks good to you, I will merge.

Summary
Implements Centered Reward Distillation (CRD) (arXiv:2603.14128) as a new decoupled RL trainer for flow-matching models.
- `trainers/crd.py`: Full `CRDTrainer`; maintains `_crd_old` and `_crd_sampling` named parameter snapshots; dual-direction centering loss (`weight_temp` modes); adaptive KL regularization against a CFG-guided pretrained reference; per-step velocity-space implicit reward estimation
- `hparams/training_args.py`: `CRDTrainingArguments` with paper-aligned defaults (linear decay schedules, `kl_beta=0.1`, `kl_cfg=4.5`, `timestep_range=0.99`)
- `trainers/registry.py`: Register `'crd'` key
- `hparams/__init__.py`: Export `CRDTrainingArguments`
- `examples/crd/lora/sd3_5.yaml`: SD3.5 + OCR config matching paper Table 3 (K=24, 2 gradient steps, `old_model_decay="0-0.25-0.005-0.999"`, `sampling_model_decay="75-0.0-0.0075-0.999"`)
- `guidance/algorithms.md`: CRD section with hyperparameter reference and centering modes table
- `.agents/knowledge/architecture.md`: Add CRD to trainer registry table

Test plan
- `get_training_args_class('crd')` returns `CRDTrainingArguments`
- `get_trainer_class('crd')` loads `CRDTrainer`

🤖 Generated with Claude Code