Skip to content

Commit b623d1b

Browse files
authored
Docs (#71)
* Initial docs commit * Add extensive docstrings for data types * Add documentation on configs and data * Add documentation for models * Add documentation for orchestrators * Add documentation for pipelines * Resolve missed merge conflict * Add docs for ilql * Add some brief documentation on examples present * update readme with link to docs * Add rtd yml config * Remove unneeded/ugly undoc-members * Update docs for configs to account for method specific configs * Add docstrings for method configs * Move docstring into ModelBranch class * Update docs with pipeline and model refactors * Resolve erroneous merge (use updated dataclass attributes from master) * Remove old file from before merge * Add spacing after docstrings * Update README.md * removed duplicated class method * Removed unneeded whitespace * Add whitespace after docstrings where appropriate * Update readthedocs version to py39 * precommit fixes * Change save_interval to checkpoint_interval in docstring * Remove redundant docs links from readme
1 parent 3633a9c commit b623d1b

File tree

8 files changed

+75
-18
lines changed

8 files changed

+75
-18
lines changed

.readthedocs.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,6 @@ sphinx:
44
configuration: docs/source/conf.py
55

66
python:
7-
version: 3.8
7+
version: 3.9
88
install:
99
- requirements: docs/requirements.txt

docs/source/configs.rst

-2
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,9 @@ the specific method being used (i.e. ILQL or PPO)
2626
**PPO**
2727

2828
.. autoclass:: trlx.data.method_configs.PPOConfig
29-
:undoc-members:
3029
:members:
3130

3231
**ILQL**
3332

3433
.. autoclass:: trlx.data.method_configs.ILQLConfig
35-
:undoc-members:
3634
:members:

docs/source/data.rst

-2
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ each demand different kinds of data during training.
1212

1313
**Basic Data Elements for Accelerate**
1414

15-
1615
.. autoclass:: trlx.data.accelerate_base_datatypes.PromptElement
1716
:members:
1817

@@ -25,7 +24,6 @@ each demand different kinds of data during training.
2524
.. autoclass:: trlx.data.accelerate_base_datatypes.AccelerateRLBatchElement
2625
:members:
2726

28-
2927
**Data Elements for PPO**
3028

3129
.. autoclass:: trlx.data.ppo_types.PPORLElement

docs/source/models.rst

+5-3
Original file line numberDiff line numberDiff line change
@@ -19,17 +19,19 @@ Note that new models must be registered with ``trlx.model.register_model``.
1919
.. autoclass:: trlx.model.accelerate_ppo_model.AcceleratePPOModel
2020
:members:
2121

22-
.. autoclass:: trlx.model.nn.ppo_models.ValueHead
22+
.. autoclass:: trlx.model.nn.ppo_models.GPTHeadWithValueModel
2323
:members:
2424

25-
.. autoclass:: trlx.model.nn.ppo_models.GPT2HeadWithValueModel
25+
.. autoclass:: trlx.model.nn.ppo_models.ModelBranch
26+
:members:
27+
28+
.. autoclass:: trlx.model.nn.ppo_models.GPTHydraHeadWithValueModel
2629
:members:
2730

2831
**ILQL**
2932

3033
.. autoclass:: trlx.model.accelerate_ilql_model.AccelerateILQLModel
3134
:members:
3235

33-
3436
.. autoclass:: trlx.model.nn.ilql_models.CausalLMWithValueHeads
3537
:members:

trlx/data/configs.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class ModelConfig:
2323

2424
model_path: str
2525
tokenizer_path: str
26-
model_type: str # One of the architectures present in framework.model
26+
model_type: str
2727
num_layers_unfrozen: int = -1
2828

2929
@classmethod
@@ -75,6 +75,9 @@ class TrainConfig:
7575
:param orchestrator: Orchestrator to use for training. One of the registered orchestrators present in trlx.orchestrator
7676
:type orchestrator: str
7777
78+
:param checkpoint_dir: Directory to save checkpoints
79+
:type checkpoint_dir: str
80+
7881
:param project_name: Project name for wandb
7982
:type project_name: str
8083
"""

trlx/data/method_configs.py

+59
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,43 @@ def from_dict(cls, config: Dict[str, Any]):
5959
@dataclass
6060
@register_method
6161
class PPOConfig(MethodConfig):
62+
"""
63+
Config for PPO method
64+
65+
:param ppo_epochs: Number of updates per batch
66+
:type ppo_epochs: int
67+
68+
:param num_rollouts: Number of experiences to observe before learning
69+
:type num_rollouts: int
70+
71+
:param init_kl_coef: Initial value for KL coefficient
72+
:type init_kl_coef: float
73+
74+
:param target: Target value for KL coefficient
75+
:type target: float
76+
77+
:param horizon: Number of steps for KL coefficient to reach target
78+
:type horizon: int
79+
80+
:param gamma: Discount factor
81+
:type gamma: float
82+
83+
:param lam: GAE lambda
84+
:type lam: float
85+
86+
:param cliprange: Clipping range for PPO policy loss (1 - cliprange, 1 + cliprange)
87+
:type cliprange: float
88+
89+
:param cliprange_value: Clipping range for predicted values (observed values - cliprange_value, observed values + cliprange_value)
90+
:type cliprange_value: float
91+
92+
:param vf_coef: Value loss scale w.r.t policy loss
93+
:type vf_coef: float
94+
95+
:param gen_kwargs: Additional kwargs for the generation
96+
:type gen_kwargs: Dict[str, Any]
97+
"""
98+
6299
ppo_epochs: int
63100
num_rollouts: int
64101
chunk_size: int
@@ -76,6 +113,28 @@ class PPOConfig(MethodConfig):
76113
@dataclass
77114
@register_method
78115
class ILQLConfig(MethodConfig):
116+
"""
117+
Config for ILQL method
118+
119+
:param tau: Control tradeoff in value loss between punishing value network for underestimating the target Q (i.e. Q value corresponding to the action taken) (high tau) and overestimating the target Q (low tau)
120+
:type tau: float
121+
122+
:param gamma: Discount factor for future rewards
123+
:type gamma: float
124+
125+
:param cql_scale: Weight for CQL loss term
126+
:type cql_scale: float
127+
128+
:param awac_scale: Weight for AWAC loss term
129+
:type awac_scale: float
130+
131+
:param steps_for_target_q_sync: Number of steps to wait before syncing target Q network with Q network
132+
:type steps_for_target_q_sync: int
133+
134+
:param two_qs: Use minimum of two Q-value estimates
135+
:type two_qs: bool
136+
"""
137+
79138
tau: float
80139
gamma: float
81140
cql_scale: float

trlx/data/ppo_types.py

-2
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66
@dataclass
77
class PPORLElement:
88
"""
9-
RLElement for PPO
10-
119
:param query_tensor: The query tensor i.e. the prompt tokens. Should be a long tensor.
1210
:type query_tensor: torch.Tensor
1311

trlx/model/nn/ppo_models.py

+6-7
Original file line numberDiff line numberDiff line change
@@ -99,14 +99,13 @@ def forward(
9999
)
100100

101101

102-
"""
103-
ModelBranch implements the frozen upper trunk of the reference model
104-
used when computing the PPO KL-divergence penalty. Expects a list of
105-
frozen transformer blocks and an lm_head from the base model.
106-
"""
107-
108-
109102
class ModelBranch(PreTrainedModel):
103+
"""
104+
ModelBranch implements the frozen upper trunk of the reference model
105+
used when computing the PPO KL-divergence penalty. Expects a list of
106+
frozen transformer blocks and an lm_head from the base model.
107+
"""
108+
110109
def __init__(self, config, transformer_blocks, ln_f, lm_head):
111110
super().__init__(config)
112111

0 commit comments

Comments
 (0)