
Commit 4e44ff8

Update
[ghstack-poisoned]
1 parent 3fb911b commit 4e44ff8

File tree

3 files changed, +13 -4 lines

.github/unittest/linux_libs/scripts_llm/environment.yml (+1 -1)

```diff
@@ -17,6 +17,6 @@ dependencies:
   - pyyaml
   - scipy
   - hydra-core
-  - transformers<4.42.0
+  - transformers
   - datasets
   - vllm
```

examples/rlhf/models/actor_critic.py (+1 -1)

```diff
@@ -34,4 +34,4 @@ def init_actor_critic(model_cfg, sys_cfg):
     critic = model.get_value_operator()
     critic_head = model.get_value_head()
 
-    return actor, VmapModule(critic), critic_head, base_model
+    return actor, VmapModule(critic, mock=True), critic_head, base_model
```

torchrl/modules/tensordict_module/common.py (+11 -2)

```diff
@@ -436,29 +436,38 @@ class VmapModule(TensorDictModuleBase):
         >>> assert (sample_in_td["x"][:, 0] == sample_in_td["y"]).all()
     """
 
-    def __init__(self, module: TensorDictModuleBase, vmap_dim=None):
+    def __init__(self, module: TensorDictModuleBase, vmap_dim=None, mock: bool = False):
         if not _has_functorch:
             raise ImportError("VmapModule requires torch>=2.0.")
         super().__init__()
         self.in_keys = module.in_keys
         self.out_keys = module.out_keys
         self.module = module
         self.vmap_dim = vmap_dim
+        self.mock = mock
         if torch.__version__ >= "2.0":
             self._vmap = torch.vmap
         else:
             import functorch
 
             self._vmap = functorch.vmap
 
+    def mock_(self, value: bool = True):
+        self.mock = value
+
     def forward(self, tensordict):
         # TODO: there is a risk of segfault if input is not a tensordict.
         # We should investigate (possibly prevent it c++ side?)
         vmap_dim = self.vmap_dim
         if vmap_dim is None:
             ndim = tensordict.ndim
             vmap_dim = ndim - 1
-        td = self._vmap(self.module, (vmap_dim,), (vmap_dim,))(tensordict)
+        if self.mock:
+            td = torch.stack(
+                [self.module(_td) for _td in tensordict.unbind(vmap_dim)], vmap_dim
+            )
+        else:
+            td = self._vmap(self.module, (vmap_dim,), (vmap_dim,))(tensordict)
         return tensordict.update(td)
 
 
```
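
For readers following along, here is a minimal sketch of what the new flag changes at call time: with `mock=True`, `VmapModule` skips `torch.vmap` entirely and instead unbinds the input tensordict along `vmap_dim`, runs the wrapped module on each slice, and stacks the results back, which should match the vmapped output. The toy module and tensordict below are illustrative only, not part of the commit; the import path for `VmapModule` follows the class docstring.

```python
# Sketch only: a toy per-slice module used to compare the default vmap path
# with the new mock path. Assumes torchrl with this commit and tensordict.
import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule

from torchrl.modules import VmapModule

# Per-slice module: reads "x", writes "y" = 2 * x.
lam = TensorDictModule(lambda x: x * 2, in_keys=["x"], out_keys=["y"])
td = TensorDict({"x": torch.ones(10, 3, 2)}, batch_size=[10])

# Default path: the module is vectorized with torch.vmap over dim 0.
vm = VmapModule(lam, 0)
out_vmap = vm(td.clone())["y"]

# New mock path: a plain Python loop (unbind -> module -> stack), no vmap.
mocked = VmapModule(lam, 0, mock=True)
out_mock = mocked(td.clone())["y"]

assert (out_vmap == out_mock).all()

# The new in-place setter toggles an existing wrapper between the two paths.
vm.mock_()  # switch to the loop-based path
assert (vm(td.clone())["y"] == out_vmap).all()
```

The loop-based path trades vectorization for compatibility: it runs the wrapped module eagerly per batch element, which is useful when the module is not vmap-compatible, and that is presumably why the RLHF example above now constructs its critic with `VmapModule(critic, mock=True)`.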
