From eb7f282a45f12e090aa44b4966c1a4c11ccf64c4 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 22 Apr 2025 14:29:33 +0100 Subject: [PATCH 1/5] Update [ghstack-poisoned] --- docs/source/reference/envs.rst | 1 - test/test_actors.py | 593 +----------------- test/test_collector.py | 405 +------------ test/test_cost.py | 82 +-- test/test_env.py | 356 +---------- torchrl/data/llm/__init__.py | 14 +- torchrl/data/llm/chat.py | 347 ----------- torchrl/data/llm/utils.py | 96 +-- torchrl/envs/custom/llm.py | 563 +---------------- torchrl/envs/transforms/llm.py | 507 +--------------- torchrl/modules/__init__.py | 4 - torchrl/modules/llm/__init__.py | 11 - torchrl/modules/llm/common.py | 92 --- torchrl/modules/llm/transformers_wrapper.py | 475 --------------- torchrl/modules/llm/vllm_wrapper.py | 632 -------------------- 15 files changed, 15 insertions(+), 4163 deletions(-) delete mode 100644 torchrl/data/llm/chat.py delete mode 100644 torchrl/modules/llm/__init__.py delete mode 100644 torchrl/modules/llm/common.py delete mode 100644 torchrl/modules/llm/transformers_wrapper.py delete mode 100644 torchrl/modules/llm/vllm_wrapper.py diff --git a/docs/source/reference/envs.rst b/docs/source/reference/envs.rst index e08f97d44d6..e5602d0553f 100644 --- a/docs/source/reference/envs.rst +++ b/docs/source/reference/envs.rst @@ -521,7 +521,6 @@ TorchRL offers a series of custom built-in environments. ChessEnv PendulumEnv TicTacToeEnv - LLMEnv LLMHashingEnv diff --git a/test/test_actors.py b/test/test_actors.py index b1f1687deee..4b4ef563035 100644 --- a/test/test_actors.py +++ b/test/test_actors.py @@ -10,31 +10,15 @@ import pytest import torch -from tensordict import ( - lazy_stack, - LazyStackedTensorDict, - NonTensorStack, - set_list_to_stack, - TensorDict, -) +from tensordict import TensorDict from tensordict.nn import CompositeDistribution, TensorDictModule from tensordict.nn.distributions import NormalParamExtractor from torch import distributions as dist, nn -from torchrl.collectors import SyncDataCollector from torchrl.data import Binary, Bounded, Categorical, Composite, MultiOneHot, OneHot -from torchrl.data.llm import LLMData from torchrl.data.llm.dataset import _has_transformers -from torchrl.envs import LLMEnv -from torchrl.modules import ( - MLP, - SafeModule, - TanhDelta, - TanhNormal, - TransformersWrapper, - vLLMWrapper, -) +from torchrl.modules import MLP, SafeModule, TanhDelta, TanhNormal from torchrl.modules.tensordict_module.actors import ( _process_action_space_spec, ActorValueOperator, @@ -51,10 +35,10 @@ if os.getenv("PYTORCH_TEST_FBCODE"): from pytorch.rl.test._utils_internal import get_default_devices - from pytorch.rl.test.mocking_classes import DummyStrDataLoader, NestedCountingEnv + from pytorch.rl.test.mocking_classes import NestedCountingEnv else: from _utils_internal import get_default_devices - from mocking_classes import DummyStrDataLoader, NestedCountingEnv + from mocking_classes import NestedCountingEnv _has_vllm = importlib.util.find_spec("vllm") is not None @@ -928,575 +912,6 @@ def test_lmhead_actorvalueoperator(device): ) == len(policy_params) -@pytest.mark.skipif(not _has_transformers, reason="missing transformers dependencies") -@pytest.mark.skipif(not _has_vllm, reason="missing vllm dependencies") -class TestLLMActor: - @pytest.fixture(scope="module") - def vllm_instance(self): - try: - import vllm - except ImportError: - pytest.skip(reason="missing vllm") - - llm_model = vllm.LLM("gpt2") - tokenizer = llm_model.get_tokenizer() - tokenizer.pad_token = 
tokenizer.eos_token - return llm_model - - @pytest.fixture(scope="module") - def transformers_instance(self): - from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel - - tokenizer = AutoTokenizer.from_pretrained("gpt2") - model = GPT2LMHeadModel(GPT2Config()).eval() - # tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m") - # model = OPTModel(OPTConfig("facebook/opt-125m")) - # tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m") - # model = OPTForCausalLM(OPTConfig()) - - tokenizer.pad_token = tokenizer.eos_token - tokenizer.padding_side = "left" - - return model, tokenizer - - @pytest.fixture(scope="module") - def transformers_instance_pretrained(self): - from transformers import AutoTokenizer, OPTForCausalLM - - # tokenizer = AutoTokenizer.from_pretrained("gpt2") - # model = GPT2LMHeadModel(GPT2Config()) - # tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m") - # model = OPTModel(OPTConfig("facebook/opt-125m")) - tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m") - model = OPTForCausalLM.from_pretrained("facebook/opt-125m") - - tokenizer.pad_token = tokenizer.eos_token - tokenizer.padding_side = "left" - - return model, tokenizer - - @pytest.mark.parametrize( - "from_text, generate, return_log_probs, tokens, attention_mask", - [ - (True, True, True, None, None), - (True, True, False, None, None), - (True, False, None, None, None), - ( - False, - True, - True, - torch.randint(1024, (1, 10)), - torch.ones(1, 10, dtype=torch.int64), - ), - (False, True, True, torch.randint(1024, (1, 10)), None), - ( - False, - True, - False, - torch.randint(1024, (1, 10)), - torch.ones(1, 10, dtype=torch.int64), - ), - (False, True, False, torch.randint(1024, (1, 10)), None), - ], - ) - def test_transformers_wrapper( - self, - from_text, - generate, - return_log_probs, - tokens, - attention_mask, - transformers_instance, - ): - torch.manual_seed(0) - - model, tokenizer = transformers_instance - - m = TransformersWrapper( - model, - tokenizer=tokenizer, - from_text=from_text, - generate=generate, - return_log_probs=return_log_probs, - ) - self._run_check( - m, - tokens, - attention_mask, - generate, - return_log_probs, - from_text, - has_logits=True, - ) - - @pytest.mark.parametrize( - "from_text, generate, return_log_probs, tokens, attention_mask", - [ - (True, True, True, None, None), - (True, True, False, None, None), - (True, False, None, None, None), - ( - False, - True, - True, - torch.randint(1024, (1, 10)), - torch.ones(1, 10, dtype=torch.int64), - ), - (False, True, True, torch.randint(1024, (1, 10)), None), - ( - False, - True, - False, - torch.randint(1024, (1, 10)), - torch.ones(1, 10, dtype=torch.int64), - ), - (False, True, False, torch.randint(1024, (1, 10)), None), - ], - ) - def test_vllm_wrapper( - self, - from_text, - generate, - return_log_probs, - tokens, - attention_mask, - vllm_instance, - ): - torch.manual_seed(0) - - model = vllm_instance - m = vLLMWrapper( - model, - from_text=from_text, - generate=generate, - return_log_probs=return_log_probs, - ) - self._run_check( - m, - tokens, - attention_mask, - generate, - return_log_probs, - from_text, - has_logits=False, - ) - - def _make_data( - self, - m, - tokens, - attention_mask, - generate, - from_text, - has_logits, - batch_size=1, - text_response=None, - tokens_response=None, - ): - lp_kwargs = {} - if from_text: - if not generate: - text_response = ( - NonTensorStack(" and another text that follows") - if text_response is None - else text_response - ) - if not 
isinstance(text_response, NonTensorStack): - if isinstance(text_response, list): - text_response = NonTensorStack(*text_response) - else: - text_response = NonTensorStack(text_response) - lp_kwargs.update({"text_response": text_response}) - tdin = LLMData( - text=NonTensorStack("a text"), **lp_kwargs, batch_size=batch_size - ) - else: - if not generate: - if tokens_response is None: - shape_response = tokens.shape - shape_response = shape_response[:-1] + (shape_response[-1] * 2,) - tokens_response = torch.randint(1024, shape_response) - lp_kwargs.update({"tokens_response": tokens_response}) - tdin = LLMData( - tokens=tokens, - attention_mask=attention_mask, - **lp_kwargs, - batch_size=batch_size, - ) - return tdin - - def _run_check( - self, - m, - tokens, - attention_mask, - generate, - return_log_probs, - from_text, - has_logits, - ): - tdin = self._make_data( - m, tokens, attention_mask, generate, from_text, has_logits - ) - if from_text and generate: - assert tdin.text_response is None - elif from_text and not generate: - assert tdin.text_response is not None - - tdin.copy() - td = m(tdin) - assert td is tdin - assert isinstance(td, LLMData) - if from_text and generate: - assert td.text_response is not None - - # TODO: vLLM may produce an attention mask when hf does not - explore consistency! - # if generate and (from_text or tdincopy.attention_mask is not None): - # assert td.attention_mask is not None, (generate, from_text, tdincopy.attention_mask is not None) - # if isinstance(td.attention_mask, torch.Tensor): - # assert td.attention_mask.shape == td.tokens.shape - # else: - # assert td.attention_mask is None, (generate, from_text) - - if not generate: - # logprobs are computed on text response of tokens_response - assert td.text_response is not None or td.tokens_response is not None - assert td.log_probs is not None - if has_logits: - assert td.logits is not None - if generate: - if return_log_probs: - assert td.log_probs is not None - assert td.log_probs.shape[-1] == td.tokens_response.shape[-1] - else: - assert td.log_probs is None - - # Test the shapes - assert td.tokens_response is not None, (generate, has_logits, from_text) - - # If from text and not generating, the tokens are not returned for now - if not (from_text and not generate): - assert td.tokens_response is not None - assert td.tokens is not None - assert td.tokens_response.shape[:-1] == td.tokens.shape[:-1] - # The convention is that the response only has new tokens - assert ( - td.tokens_response[..., : td.tokens.shape[-1]] - != td.tokens[..., : td.tokens_response.shape[-1]] - ).any(), (generate, from_text) - - @pytest.mark.parametrize( - "from_text, tokens, attention_mask", - [ - ( - False, - torch.randint(1024, (1, 10)), - torch.ones(1, 10, dtype=torch.int64), - ), - (False, torch.randint(1024, (1, 10)), None), - (True, None, None), - ], - ) - def test_transformers_logprobs( - self, from_text, tokens, attention_mask, transformers_instance - ): - torch.manual_seed(0) - model, tokenizer = transformers_instance - - m_generate = TransformersWrapper( - model, - tokenizer=tokenizer, - from_text=from_text, - generate=True, - return_log_probs=True, - ) - m_logprobs = TransformersWrapper( - model, tokenizer=tokenizer, from_text=from_text, generate=False - ) - self._check_lps( - m_generate, m_logprobs, tokens, attention_mask, from_text, has_logits=False - ) - - @pytest.mark.parametrize( - "pad_output, from_text, tokens, attention_mask", - [ - (True, True, None, None), - (False, True, None, None), - ( - True, - False, - 
torch.randint(1024, (1, 10)), - torch.ones(1, 10, dtype=torch.int64), - ), - (True, False, torch.randint(1024, (1, 10)), None), - ], - ) - def test_vllm_logprobs( - self, from_text, tokens, attention_mask, pad_output, vllm_instance - ): - torch.manual_seed(0) - - model = vllm_instance - m_generate = vLLMWrapper( - model, - from_text=from_text, - generate=True, - return_log_probs=True, - pad_output=pad_output, - ) - m_logprobs = vLLMWrapper( - model, from_text=from_text, generate=False, pad_output=pad_output - ) - self._check_lps( - m_generate, - m_logprobs, - tokens, - attention_mask, - from_text, - has_logits=False, - tol=1e-1, - ) - - def _check_lps( - self, - model_generate, - model_logprobs, - tokens, - attention_mask, - from_text, - has_logits, - tol=1e-2, - ): - # Checks that the log-probs gathered with generate=False equate those with generate=True - tdin_genetate = self._make_data( - model_generate, tokens, attention_mask, True, from_text, has_logits - ) - td_generate = model_generate(tdin_genetate) - tdin_logprobs = self._make_data( - model_logprobs, - tokens, - attention_mask, - False, - from_text, - has_logits, - tokens_response=td_generate.tokens_response, - text_response=td_generate.text_response, - ) - td_logprobs = model_logprobs(tdin_logprobs) - assert td_generate.log_probs.shape == td_generate.tokens_response.shape - assert td_logprobs.log_probs.shape == td_logprobs.tokens_response.shape - assert td_logprobs.log_probs.shape == td_generate.tokens_response.shape - torch.testing.assert_close( - td_generate.log_probs, td_logprobs.log_probs, rtol=tol, atol=tol - ) - - @pytest.mark.parametrize("pad", [True, False]) - @pytest.mark.parametrize("generate", [True, False]) - @pytest.mark.parametrize("use_tensorclass", [True, False]) - def test_vllm_batch_run(self, pad, generate, use_tensorclass, vllm_instance): - # Test generate - padding combinations - policy = vLLMWrapper( - vllm_instance, - from_text=True, - generate=generate, - return_log_probs=True, - pad_output=pad, - generate_kwargs={"max_tokens": 10000}, - ) - if generate: - data = LazyStackedTensorDict( - *TensorDict( - text=NonTensorStack("a string", "another very long string"), - batch_size=[2], - ).unbind(0) - ) - else: - data = LazyStackedTensorDict( - *TensorDict( - text=NonTensorStack("a string", "another very long string"), - text_response=NonTensorStack( - " is a string", " is still a very long string" - ), - batch_size=[2], - ).unbind(0) - ) - if use_tensorclass: - data = LLMData.from_tensordict(data) - output = policy(data) - try: - log_probs = output.get("log_probs") - except Exception: - log_probs = output.get("log_probs", as_list=True) - if pad: - assert isinstance(log_probs, torch.Tensor) - else: - assert isinstance(log_probs, list) - text = output.get("text", as_list=True) - # TODO: this is not ideal... 
- if use_tensorclass: - assert isinstance(text, list) - else: - assert isinstance(text, NonTensorStack) - text_response = output.get("text_response", as_list=True) - if use_tensorclass: - assert isinstance(text_response, list) - else: - assert isinstance(text_response, NonTensorStack) - try: - tokens_response = output.get("tokens_response") - except Exception: - tokens_response = output.get("tokens_response", as_list=True) - if pad: - assert isinstance(tokens_response, torch.Tensor) - else: - assert isinstance(tokens_response, list) - try: - tokens = output.get("tokens") - except Exception: - tokens = output.get("tokens", as_list=True) - if not generate: - assert tokens is None - elif pad: - assert isinstance(tokens, torch.Tensor), tokens - else: - assert isinstance(tokens, list) - - @pytest.mark.parametrize("from_text", [True]) - def test_vllm_collection(self, vllm_instance, from_text): - policy = vLLMWrapper( - vllm_instance, - return_log_probs=True, - generate_kwargs={"max_tokens": 32}, - from_text=from_text in (True, None), - ) - tokenizer = vllm_instance.get_tokenizer() - self._run_check_collector(policy, from_text=from_text, tokenizer=tokenizer) - - def test_transformers_collection(self): - ... - - @classmethod - def env_constructor(cls, **kwargs): - def make(): - # if kwargs.get("from_text", True): - dl = DummyStrDataLoader(batch_size=32) - # else: - # dl = DummyTensorDataLoader(batch_size=32) - env = LLMEnv.from_dataloader( - dl, - batch_size=4, - repeats=4, - **kwargs, - ) - assert env.batch_size == (16,) - return env - - return make - - def _run_check_collector(self, policy, from_text, tokenizer): - if from_text is None: - kwargs = {"eos_token_id": tokenizer.eos_token_id} - else: - kwargs = { - "from_text": from_text, - "tokenizer": tokenizer, - "eos_token_id": tokenizer.eos_token_id, - } - collector = SyncDataCollector( - self.env_constructor(**kwargs), - policy=policy, - frames_per_batch=32, - total_frames=128, - use_buffers=False, - ) - t = 0 - for data in collector: - assert isinstance(data, LazyStackedTensorDict) - assert isinstance(data.reshape(-1).get("text_response"), NonTensorStack) - # action - assert "text_response" in data - assert "tokens_response" in data - # obs - assert "text" in data - assert ("next", "text") in data - # tokens - assert "tokens" in data - - t += data.numel() - assert collector._frames == t - assert t < 512, t - # assert ("next", "tokens") in data - - def test_vllm_generate_multiple_trajs(self, vllm_instance): - policy = vLLMWrapper( - vllm_instance, - return_log_probs=True, - generate_kwargs={"n": 10, "max_tokens": 1024}, - inplace=False, - ) - data = TensorDict( - text=NonTensorStack("a string", "another very long string"), batch_size=2 - ) - data = policy(data) - - @set_list_to_stack(True) - @pytest.mark.parametrize("from_text", [True, False]) - @pytest.mark.parametrize("generate", [True, False]) - def test_transformers_long_sequences( - self, from_text, generate, transformers_instance_pretrained - ): - torch.manual_seed(42) - model, tokenizer = transformers_instance_pretrained - prompts = [ - "The quick brown fox jumps over the lazy dog.", # Likely to finish soon - "Once upon a time in a land far, far away, there was a", # Likely to continue longer - "In the beginning, the universe was created. 
This has made a lot of people very angry and been widely regarded as a bad move.", - ] - data = lazy_stack([TensorDict() for _ in range(len(prompts))]) - data["text"] = prompts - eos_token_id = tokenizer.convert_tokens_to_ids(",") - if not from_text: - data["tokens"] = tokenizer(data["text"])["input_ids"] - data["attention_mask"] = ( - 0 * data.get("tokens", as_nested_tensor=True, layout=torch.strided) + 1 - ) - if not generate: - # we need responses - responses = prompts[1:] + [" et dolore magna aliqua."] - data["text_response"] = responses - if not from_text: - data["tokens_response"] = tokenizer(data["text_response"])["input_ids"] - # make sure dimensions are ragged for tokens entries - if "tokens" in data: - assert data.get_item_shape("tokens")[-1] == -1 - if "tokens_response" in data: - assert data.get_item_shape("tokens_response")[-1] == -1 - generate_kwargs = {} - if generate: - generate_kwargs = { - "max_new_tokens": 128, # Set a reasonable number of new tokens to generate - "min_length": 20, # Ensure a minimum length for the generated sequence - "pad_token_id": tokenizer.pad_token_id, # Use the tokenizer's pad token - "forced_eos_token_id": eos_token_id, # Use comma as an EOS token - } - policy = TransformersWrapper( - model, - tokenizer=tokenizer, - from_text=from_text, - generate=generate, - return_log_probs=True, - # TODO: use n trajs - generate_kwargs=generate_kwargs, - ) - data_policy = policy(data) - if "tokens" in data_policy: - assert data_policy.get_item_shape("tokens")[-1] == -1 - if "tokens_response" in data_policy: - assert ( - data_policy.get_item_shape("tokens_response")[-1] == -1 - ) # TODO: this fails - - if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_collector.py b/test/test_collector.py index fdd57a6be99..249c7ecf9d2 100644 --- a/test/test_collector.py +++ b/test/test_collector.py @@ -5,11 +5,9 @@ from __future__ import annotations import argparse -import asyncio import contextlib import functools import gc -import importlib import os import subprocess import sys @@ -52,26 +50,21 @@ MultiSyncDataCollector, ) -from torchrl.collectors.llm import LLMCollector from torchrl.collectors.utils import split_trajectories from torchrl.data import ( Composite, LazyMemmapStorage, - LazyStackStorage, LazyTensorStorage, NonTensor, ReplayBuffer, TensorSpec, Unbounded, ) -from torchrl.data.llm.dataset import _has_transformers from torchrl.data.utils import CloudpickleWrapper from torchrl.envs import ( - AsyncEnvPool, EnvBase, EnvCreator, InitTracker, - LLMEnv, ParallelEnv, SerialEnv, StepCounter, @@ -85,13 +78,7 @@ PARTIAL_MISSING_ERR, RandomPolicy, ) -from torchrl.modules import ( - Actor, - OrnsteinUhlenbeckProcessModule, - SafeModule, - TransformersWrapper, - vLLMWrapper, -) +from torchrl.modules import Actor, OrnsteinUhlenbeckProcessModule, SafeModule if os.getenv("PYTORCH_TEST_FBCODE"): IS_FB = True @@ -116,7 +103,6 @@ DiscreteActionConvPolicy, DiscreteActionVecMockEnv, DiscreteActionVecPolicy, - DummyStrDataLoader, EnvThatErrorsAfter10Iters, EnvWithDynamicSpec, HeterogeneousCountingEnv, @@ -149,7 +135,6 @@ DiscreteActionConvPolicy, DiscreteActionVecMockEnv, DiscreteActionVecPolicy, - DummyStrDataLoader, EnvThatErrorsAfter10Iters, EnvWithDynamicSpec, HeterogeneousCountingEnv, @@ -167,7 +152,6 @@ PYTHON_3_7 = sys.version_info.major == 3 and sys.version_info.minor == 7 TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version) 
_has_cuda = torch.cuda.is_available() -_has_vllm = importlib.util.find_spec("vllm") is not None class WrappablePolicy(nn.Module): @@ -3561,393 +3545,6 @@ def test_weight_update(self): collector.shutdown() -@pytest.mark.skipif(not _has_transformers, reason="missing transformers dependencies") -@pytest.mark.skipif(not _has_vllm, reason="missing vllm dependencies") -class TestLLMCollector: - @pytest.fixture(scope="module") - def vllm_instance(self): - try: - import vllm - except ImportError: - pytest.skip(reason="missing vllm") - - llm_model = vllm.LLM("gpt2") - tokenizer = llm_model.get_tokenizer() - tokenizer.pad_token = tokenizer.eos_token - return llm_model - - @pytest.fixture(scope="module") - def vllm_instance_opt(self): - try: - import vllm - except ImportError: - pytest.skip(reason="missing vllm") - - llm_model = vllm.LLM("facebook/opt-125m") - tokenizer = llm_model.get_tokenizer() - tokenizer.pad_token = tokenizer.eos_token - return llm_model - - @pytest.fixture(scope="module") - def transformers_instance(self): - from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel - - tokenizer = AutoTokenizer.from_pretrained("gpt2") - model = GPT2LMHeadModel(GPT2Config()).eval() - # tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m") - # model = OPTModel(OPTConfig("facebook/opt-125m")) - # tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m") - # model = OPTForCausalLM(OPTConfig()) - - tokenizer.pad_token = tokenizer.eos_token - tokenizer.padding_side = "left" - - return model, tokenizer - - @pytest.mark.slow - @pytest.mark.parametrize("rb", [True, False]) - @pytest.mark.parametrize("total_steps", [1, 10, 20]) - def test_llm_collector_with_vllm(self, rb, total_steps, vllm_instance): - # NOTE: if VLLM fails with CUDA multiprocessing, try setting - # `export VLLM_WORKER_MULTIPROC_METHOD=spawn` - policy = vLLMWrapper(vllm_instance) - tokenizer = vllm_instance.get_tokenizer() - self._run_collector_test(total_steps, rb, policy, tokenizer) - - @pytest.mark.slow - @pytest.mark.parametrize("rb", [True, False]) - @pytest.mark.parametrize("total_steps", [1, 10, 20]) - def test_llm_collector_with_transformers( - self, rb, total_steps, transformers_instance - ): - model, tokenizer = transformers_instance - policy = TransformersWrapper( - model, - tokenizer=tokenizer, - from_text=True, - generate=True, - return_log_probs=True, - ) - self._run_collector_test(total_steps, rb, policy, tokenizer) - - def _run_collector_test(self, total_steps, rb, policy, tokenizer): - bsz = 4 - dataloader = DummyStrDataLoader(bsz) - - env = LLMEnv.from_dataloader( - dataloader=dataloader, - from_text=True, - batch_size=bsz, - group_repeats=True, - eos_token_id=tokenizer.eos_token_id, - ) - if rb: - rb = ReplayBuffer(storage=LazyStackStorage(max_size=total_steps * 2)) - else: - rb = None - collector = LLMCollector( - env=env, - policy_factory=lambda: policy, - steps_per_batch=env.batch_size[0], - replay_buffer=rb, - total_steps=total_steps, - ) - - stack = [] - for data in collector: - # Should be moved to replay buffer - if rb is not None: - assert data is None - else: - stack.append(data) - - if rb is not None: - # Now check the buffer - assert len(rb) >= total_steps - sample = rb.sample(4) - assert sample.shape == (4,) - assert not sample._has_exclusive_keys - # Should match length - assert len(sample["text"]) == 4 - # assert len(sample["text"][0]) == 10, sample["text"][0] - # Should be non-empty - assert sample["text_response"] is not None - for i in range(4): - # Check that there are more 
chars in the next step - assert len(sample["text"][i]) < len(sample["next", "text"][i]) - else: - stack = torch.cat(stack) - assert not stack._has_exclusive_keys - assert stack.numel() == max(-(total_steps // -4) * 4, 4) - stack = stack.view(-1) - for i in range(stack.numel()): - # Check that there are more chars in the next step - assert len(stack["text"][i]) < len(stack["next", "text"][i]) - assert collector._frames >= total_steps - - @pytest.mark.slow - @pytest.mark.asyncio - async def test_llm_collector_start(self, vllm_instance): - total_steps = 20 - policy = vLLMWrapper(vllm_instance) - vllm_instance.get_tokenizer() - bsz = 4 - dataloader = DummyStrDataLoader(bsz) - - env = LLMEnv.from_dataloader( - dataloader=dataloader, - from_text=True, - batch_size=bsz, - group_repeats=True, - ) - - rb = ReplayBuffer(storage=LazyStackStorage(max_size=total_steps * 2)) - collector = LLMCollector( - env=env, - policy_factory=lambda: policy, - steps_per_batch=env.batch_size[0], - replay_buffer=rb, - total_steps=total_steps, - ) - torchrl_logger.info("starting") - collector.start() - - j = 0 - while True: - if not len(rb): - await asyncio.sleep(1) # Use asyncio.sleep instead of time.sleep - sample = rb.sample(10) - assert sample.ndim == 1 - for i in range(10): - # Check that there are more chars in the next step - assert len(sample["text"][i]) < len(sample["next", "text"][i]) - assert not sample._has_exclusive_keys, sample - j += 1 - if j == 5: - break - assert collector._frames >= total_steps - - try: - # Assuming collector._task is the task created in start() - await asyncio.wait_for(collector.async_shutdown(), timeout=30) - except asyncio.TimeoutError: - torchrl_logger.info("Collector shutdown timed out") - - @pytest.mark.slow - @pytest.mark.parametrize("rb", [False, True]) - @pytest.mark.parametrize("yield_only_last_steps", [False, True]) - def test_llm_collector_completed( - self, vllm_instance_opt, rb, yield_only_last_steps - ): - torch.manual_seed(0) - policy = vLLMWrapper(vllm_instance_opt) - tokenizer = vllm_instance_opt.get_tokenizer() - bsz = 4 - total_steps = 20 - max_steps = 20 - dataloader = DummyStrDataLoader(bsz) - - env = LLMEnv.from_dataloader( - dataloader=dataloader, - from_text=True, - batch_size=bsz, - group_repeats=True, - eos_token_id=tokenizer.eos_token_id, - ) - # To make sure the env breaks at some point - env = env.append_transform(StepCounter(max_steps=max_steps)) - - if rb: - rb = ReplayBuffer(storage=LazyStackStorage(max_size=total_steps * 2)) - else: - rb = None - collector = LLMCollector( - env=env, - policy_factory=lambda: policy, - steps_per_batch=env.batch_size[0], - replay_buffer=rb, - total_steps=total_steps, - yield_completed_trajectories=True, - yield_only_last_steps=yield_only_last_steps, - ) - assert collector.yield_completed_trajectories - assert collector.yield_only_last_steps is yield_only_last_steps - - cur_total_steps = 0 - has_found_one_with_more_steps = False - for data in collector: - if rb is None: - assert data.ndim == 1 - # assert (data["next", "step_count"] < max_steps-1).all() - cur_total_steps += data.numel() - for i in range(data.numel()): - if data[i]["next", "step_count"] == max_steps: - continue - if data[i]["text_response"]: - # Check that there are more chars in the next step - assert len(data["text"][i]) < len(data["next", "text"][i]), ( - i, - data[i]["next", "step_count"], - data[i]["next", "done"], - data[i]["text_response"], - ) - else: - assert len(data["text"][i]) == len(data["next", "text"][i]), ( - i, - data[i]["next", 
"step_count"], - data[i]["next", "done"], - data[i]["text_response"], - ) - - if yield_only_last_steps: - assert data.shape == (1,) - else: - has_found_one_with_more_steps |= data.numel() > 1 - else: - assert data is None - sample = rb.sample(5) - for i in range(sample.numel()): - if sample[i]["next", "step_count"] == max_steps: - continue - if sample[i]["text_response"]: - # Check that there are more chars in the next step - assert len(sample["text"][i]) < len( - sample["next", "text"][i] - ), ( - i, - sample[i]["next", "step_count"], - sample[i]["next", "done"], - sample[i]["text_response"], - ) - else: - assert len(sample["text"][i]) == len( - sample["next", "text"][i] - ), ( - i, - sample[i]["next", "step_count"], - sample[i]["next", "done"], - sample[i]["text_response"], - ) - - assert sample.ndim == 1 - assert sample.shape == (5,) - assert (sample["next", "step_count"] < 99).all() - cur_total_steps += 1 - assert collector._frames >= cur_total_steps - if rb is None and not yield_only_last_steps: - assert has_found_one_with_more_steps - assert collector._frames >= total_steps - - @pytest.mark.slow - @pytest.mark.parametrize("rb", [False, True]) - @pytest.mark.parametrize("yield_only_last_steps", [False, True]) - def test_llm_collector_completed_async( - self, vllm_instance_opt, rb, yield_only_last_steps - ): - torch.manual_seed(0) - policy = vLLMWrapper(vllm_instance_opt) - tokenizer = vllm_instance_opt.get_tokenizer() - bsz = 4 - total_steps = 20 - max_steps = 20 - dataloader = DummyStrDataLoader(bsz) - - def env_maker(): - env = LLMEnv.from_dataloader( - dataloader=dataloader, - from_text=True, - batch_size=(), - group_repeats=True, - eos_token_id=tokenizer.eos_token_id, - ) - # To make sure the env breaks at some point - env = env.append_transform(StepCounter(max_steps=max_steps)) - return env - - env = AsyncEnvPool([env_maker] * bsz, backend="threading", stack="lazy") - - if rb: - rb = ReplayBuffer(storage=LazyStackStorage(max_size=total_steps * 2)) - else: - rb = None - collector = LLMCollector( - env=env, - policy_factory=lambda: policy, - steps_per_batch=env.batch_size[0], - replay_buffer=rb, - total_steps=total_steps, - yield_completed_trajectories=True, - yield_only_last_steps=yield_only_last_steps, - ) - assert collector.yield_completed_trajectories - assert collector.yield_only_last_steps is yield_only_last_steps - - cur_total_steps = 0 - has_found_one_with_more_steps = False - for data in collector: - if rb is None: - assert data.ndim == 1 - # assert (data["next", "step_count"] < max_steps-1).all() - cur_total_steps += data.numel() - for i in range(data.numel()): - if data[i]["next", "step_count"] == max_steps: - continue - if data[i]["text_response"]: - # Check that there are more chars in the next step - assert len(data["text"][i]) < len(data["next", "text"][i]), ( - i, - data[i]["next", "step_count"], - data[i]["next", "done"], - data[i]["text_response"], - ) - else: - assert len(data["text"][i]) == len(data["next", "text"][i]), ( - i, - data[i]["next", "step_count"], - data[i]["next", "done"], - data[i]["text_response"], - ) - - if yield_only_last_steps: - assert data.shape == (1,) - else: - has_found_one_with_more_steps |= data.numel() > 1 - else: - assert data is None - sample = rb.sample(5) - for i in range(sample.numel()): - if sample[i]["next", "step_count"] == max_steps: - continue - if sample[i]["text_response"]: - # Check that there are more chars in the next step - assert len(sample["text"][i]) < len( - sample["next", "text"][i] - ), ( - i, - sample[i]["next", 
"step_count"], - sample[i]["next", "done"], - sample[i]["text_response"], - ) - else: - assert len(sample["text"][i]) == len( - sample["next", "text"][i] - ), ( - i, - sample[i]["next", "step_count"], - sample[i]["next", "done"], - sample[i]["text_response"], - ) - - assert sample.ndim == 1 - assert sample.shape == (5,) - assert (sample["next", "step_count"] < 99).all() - cur_total_steps += 1 - assert collector._frames >= cur_total_steps - if rb is None and not yield_only_last_steps: - assert has_found_one_with_more_steps - assert collector._frames >= total_steps - - if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_cost.py b/test/test_cost.py index 771c57b3e4a..758dce5ef01 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -41,7 +41,7 @@ ) from tensordict.nn.distributions.composite import _add_suffix from tensordict.nn.utils import Buffer -from tensordict.utils import set_capture_non_tensor_stack, unravel_key +from tensordict.utils import unravel_key from torch import autograd, nn from torchrl._utils import _standardize from torchrl.data import Bounded, Categorical, Composite, MultiOneHot, OneHot, Unbounded @@ -147,10 +147,7 @@ get_available_devices, get_default_devices, ) - from pytorch.rl.test.mocking_classes import ( - ContinuousActionConvMockEnv, - DummyStrDataLoader, - ) + from pytorch.rl.test.mocking_classes import ContinuousActionConvMockEnv else: from _utils_internal import ( # noqa _call_value_nets, @@ -158,7 +155,7 @@ get_available_devices, get_default_devices, ) - from mocking_classes import ContinuousActionConvMockEnv, DummyStrDataLoader + from mocking_classes import ContinuousActionConvMockEnv _has_functorch = True try: @@ -16675,79 +16672,6 @@ def forward(self, td, mode): assert exploration_type() == ExplorationType.RANDOM -class TestPPO4LLMs: - @pytest.mark.skipif( - not _has_transformers, reason="transformers lib required to test PPO with LLMs" - ) - @set_capture_non_tensor_stack(False) - @pytest.mark.parametrize("from_text", [True, False]) - def test_hf(self, from_text): - from torchrl.envs import LLMEnv, Transform - from torchrl.modules import TransformersWrapper - from transformers import AutoTokenizer, OPTConfig, OPTForCausalLM - - tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m") - tokenizer.pad_token = tokenizer.eos_token - - model = OPTForCausalLM(OPTConfig()).eval() - policy_inference = TransformersWrapper( - model, - tokenizer=tokenizer, - generate=True, - from_text=from_text, - return_log_probs=True, - ) - policy_train = TransformersWrapper( - model, tokenizer=tokenizer, generate=False, from_text=False - ) - for p in policy_train.parameters(): - assert p.requires_grad - # Create some fake data - dl = DummyStrDataLoader(batch_size=32) - llm_env = LLMEnv.from_dataloader( - dl, - tokenizer=tokenizer if not from_text else None, - batch_size=(32,), - from_text=True, - eos_token_id=tokenizer.eos_token_id, - ) - - class RewardTransform(Transform): - def _step(self, td, next_td): - next_td["reward"] = torch.randn_like( - td["tokens_response"], dtype=torch.float - ).unsqueeze(-1) - return next_td - - def transform_reward_spec(self, reward_spec): - return reward_spec.set( - "reward", Unbounded((*reward_spec.shape, -1, 1), dtype=torch.float) - ) - - llm_env = llm_env.append_transform(RewardTransform()) - with torch.no_grad(): - data = llm_env.rollout(3, policy_inference) - data = data.view(-1) - assert 
data["tokens_response"].shape[-1] == 20 - # Make some fake advantages: - data["advantage"] = torch.randn_like(data["next", "reward"]) - - loss = ClipPPOLoss( - actor_network=policy_train, - ) - loss_vals = loss(data) - - assert "loss_objective" in loss_vals - assert "loss_entropy" in loss_vals - assert loss_vals["loss_objective"].requires_grad - assert loss_vals["loss_entropy"].requires_grad - assert "clip_fraction" in loss_vals - assert "kl_approx" in loss_vals - assert "entropy" in loss_vals - assert "ESS" in loss_vals - assert "loss_critic" not in loss_vals - - if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) diff --git a/test/test_env.py b/test/test_env.py index a9edc976335..c9a8472f6dc 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -34,7 +34,7 @@ TensorDictBase, ) from tensordict.nn import TensorDictModuleBase -from tensordict.tensorclass import NonTensorData, NonTensorStack, TensorClass +from tensordict.tensorclass import NonTensorStack, TensorClass from tensordict.utils import _unravel_key_to_tuple from torch import nn @@ -47,11 +47,9 @@ CatTensors, ChessEnv, ConditionalSkip, - DataLoadingPrimer, DoubleToFloat, EnvBase, EnvCreator, - LLMEnv, LLMHashingEnv, ParallelEnv, PendulumEnv, @@ -64,7 +62,6 @@ from torchrl.envs.libs.dm_control import _has_dmc, DMControlEnv from torchrl.envs.libs.gym import _has_gym, gym_backend, GymEnv, GymWrapper from torchrl.envs.transforms import Compose, StepCounter, TransformedEnv -from torchrl.envs.transforms.llm import as_padded_tensor from torchrl.envs.transforms.transforms import ( AutoResetEnv, AutoResetTransform, @@ -136,8 +133,6 @@ DiscreteActionConvMockEnvNumpy, DiscreteActionVecMockEnv, DummyModelBasedEnvBase, - DummyStrDataLoader, - DummyTensorDataLoader, EnvThatDoesNothing, EnvWithDynamicSpec, EnvWithMetadata, @@ -179,8 +174,6 @@ DiscreteActionConvMockEnvNumpy, DiscreteActionVecMockEnv, DummyModelBasedEnvBase, - DummyStrDataLoader, - DummyTensorDataLoader, EnvThatDoesNothing, EnvWithDynamicSpec, EnvWithMetadata, @@ -4585,353 +4578,6 @@ def test_skipping_history_env_collector(self, device_env, collector_cls): count += 1 -class TestLLMEnv: - @pytest.fixture(scope="class", autouse=True) - def set_capture(self): - with set_capture_non_tensor_stack(False): - yield None - return - - @pytest.mark.skipif(not _has_transformers, reason="test requires transformers") - @pytest.mark.parametrize( - "from_text,stack_method", - [ - [True, None], - [False, "as_padded_tensor"], - # TODO: a bit experimental, fails with check_env_specs - # [False, "as_nested_tensor"], - [False, None], - ], - ) - @pytest.mark.parametrize("dl_batch_size", [1, 4]) - @pytest.mark.parametrize("env_batch_size", [None, 0, (), 4]) - @pytest.mark.parametrize("device", [None, "cpu"]) - def test_llm_env( - self, from_text, stack_method, device, dl_batch_size, env_batch_size - ): - if from_text: - primer = DataLoadingPrimer( - dataloader=DummyStrDataLoader(batch_size=dl_batch_size), - batch_size=env_batch_size, - ) - else: - if stack_method is None: - stack_method = as_padded_tensor - primer = DataLoadingPrimer( - dataloader=DummyTensorDataLoader( - batch_size=dl_batch_size, padding=True - ), - stack_method=stack_method, - batch_size=env_batch_size, - ) - with pytest.warns(UserWarning, match="eos_token_id"): - env = LLMEnv( - from_text=from_text, - device=device, - batch_size=primer.batch_size, - ) - env = env.append_transform(primer) - if env_batch_size is None: - assert 
env.batch_size == torch.Size((dl_batch_size,)) - else: - if not isinstance(env_batch_size, tuple): - env_batch_size = ( - torch.Size(()) - if env_batch_size == 0 - else torch.Size((env_batch_size,)) - ) - assert env.batch_size == env_batch_size - - env.check_env_specs(break_when_any_done="both") - - @pytest.mark.skipif(not _has_transformers, reason="test requires transformers") - @pytest.mark.parametrize("tokenizer", [True, False]) - @pytest.mark.parametrize( - "from_text,stack_method", - [ - [True, None], - [False, "as_padded_tensor"], - [False, None], - ], - ) - @pytest.mark.parametrize("device", [None, "cpu"]) - @pytest.mark.parametrize("dl_batch_size", [1, 4]) - @pytest.mark.parametrize("env_batch_size", [None, 0, (), 4]) - def test_llm_from_dataloader( - self, - from_text, - stack_method, - device, - dl_batch_size, - env_batch_size, - tokenizer, - ): - from transformers import AutoTokenizer - - if tokenizer and from_text: - tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") - else: - tokenizer = None - if from_text: - kwargs = { - "dataloader": DummyStrDataLoader(batch_size=dl_batch_size), - } - else: - if stack_method is None: - stack_method = as_padded_tensor - kwargs = { - "dataloader": DummyTensorDataLoader( - padding=True, batch_size=dl_batch_size - ), - "stack_method": stack_method, - } - kwargs.update( - { - "batch_size": env_batch_size, - "from_text": from_text, - "device": device, - "has_attention": False, - "tokenizer": tokenizer, - } - ) - with pytest.warns(UserWarning, match="eos_token_id"): - env = LLMEnv.from_dataloader(**kwargs) - if env_batch_size is None: - assert env.batch_size == torch.Size((dl_batch_size,)) - else: - if not isinstance(env_batch_size, tuple): - env_batch_size = ( - torch.Size(()) - if env_batch_size == 0 - else torch.Size((env_batch_size,)) - ) - assert env.batch_size == env_batch_size - env.check_env_specs(break_when_any_done="both") - - def policy(td): - if from_text and tokenizer is None: - if not td.shape: - td[LLMEnv._DEFAULT_ACTION_STR_KEY] = NonTensorData( - "", device=device - ) - else: - td[LLMEnv._DEFAULT_ACTION_STR_KEY] = NonTensorStack( - *[ - NonTensorData("", device=device) - for _ in range(td.shape[0]) - ] - ) - else: - td[LLMEnv._DEFAULT_ACTION_TOKENS_KEY] = torch.ones( - td.shape + (1,), dtype=torch.int64 - ) - return td - - r = env.rollout(10, policy) - if env.batch_size == (): - assert r.ndim == 1 - r = r.unsqueeze(0) - else: - assert r.ndim == 2 - if from_text and tokenizer is None: - assert isinstance(r[0, 0][LLMEnv._DEFAULT_STR_KEY], str) - assert isinstance(r[0, 1][LLMEnv._DEFAULT_STR_KEY], str) - assert ( - r[0, 0][LLMEnv._DEFAULT_STR_KEY] - == r[0, 1][LLMEnv._DEFAULT_STR_KEY][ - : -len(r[0, 0][LLMEnv._DEFAULT_ACTION_STR_KEY]) - ] - ), ( - r[0, 0][LLMEnv._DEFAULT_STR_KEY], - r[0, 0][LLMEnv._DEFAULT_ACTION_STR_KEY], - r[0, 0]["next", LLMEnv._DEFAULT_STR_KEY], - r[0, 1][LLMEnv._DEFAULT_STR_KEY], - ) - assert ( - r[0, 1][LLMEnv._DEFAULT_STR_KEY] - == r[0, 2][LLMEnv._DEFAULT_STR_KEY][ - : -len(r[0, 1][LLMEnv._DEFAULT_ACTION_STR_KEY]) - ] - ) - assert ( - r[-1, 0][LLMEnv._DEFAULT_STR_KEY] - == r[-1, 1][LLMEnv._DEFAULT_STR_KEY][ - : -len(r[-1, 0][LLMEnv._DEFAULT_ACTION_STR_KEY]) - ] - ) - assert ( - r[-1, 1][LLMEnv._DEFAULT_STR_KEY] - == r[-1, 2][LLMEnv._DEFAULT_STR_KEY][ - : -len(r[-1, 1][LLMEnv._DEFAULT_ACTION_STR_KEY]) - ] - ) - elif tokenizer is None: - assert ( - r[0, 0][LLMEnv._DEFAULT_TOKEN_KEY] - == r[0, 1][LLMEnv._DEFAULT_TOKEN_KEY][:-1] - ).all() - assert ( - r[0, 1][LLMEnv._DEFAULT_TOKEN_KEY] - == r[0, 
2][LLMEnv._DEFAULT_TOKEN_KEY][:-1] - ).all() - assert ( - r[-1, 0][LLMEnv._DEFAULT_TOKEN_KEY] - == r[-1, 1][LLMEnv._DEFAULT_TOKEN_KEY][:-1] - ).all() - assert ( - r[-1, 1][LLMEnv._DEFAULT_TOKEN_KEY] - == r[-1, 2][LLMEnv._DEFAULT_TOKEN_KEY][:-1] - ).all() - - @pytest.mark.parametrize( - "from_text,stack_method", - [ - [True, None], - [False, "as_padded_tensor"], - # TODO: a bit experimental, fails with check_env_specs - # [False, "as_nested_tensor"], - [False, None], - ], - ) - @pytest.mark.parametrize("device", [None, "cpu"]) - @pytest.mark.parametrize("dl_batch_size", [1, 4]) - @pytest.mark.parametrize("env_batch_size", [None, 0, (), 4]) - @pytest.mark.parametrize("repeats", [3]) - def test_llm_from_dataloader_repeats( - self, from_text, stack_method, device, env_batch_size, dl_batch_size, repeats - ): - if from_text: - kwargs = { - "dataloader": DummyStrDataLoader(batch_size=dl_batch_size), - "repeats": repeats, - } - else: - if stack_method is None: - stack_method = as_padded_tensor - kwargs = { - "dataloader": DummyTensorDataLoader( - padding=True, batch_size=dl_batch_size - ), - "stack_method": stack_method, - "repeats": repeats, - } - kwargs.update( - { - "batch_size": env_batch_size, - "from_text": from_text, - "device": device, - "has_attention": False, - } - ) - with pytest.warns(UserWarning, match="eos_token_id"): - env = LLMEnv.from_dataloader(**kwargs) - assert env.transform.repeats == repeats - - max_steps = 3 - env.append_transform(StepCounter(max_steps=max_steps)) - - def policy(td): - if from_text: - if not td.shape: - td[LLMEnv._DEFAULT_ACTION_STR_KEY] = "" - else: - td[LLMEnv._DEFAULT_ACTION_STR_KEY] = NonTensorStack( - *["" for _ in range(td.shape[0])] - ) - else: - td[LLMEnv._DEFAULT_ACTION_TOKENS_KEY] = torch.ones( - td.shape + (1,), dtype=torch.int64 - ) - return td - - r = env.rollout(100, policy, break_when_any_done=False) - # check that r at reset is always the same - r_reset = r[..., ::max_steps] - if from_text: - all_strings = r_reset.view(-1)[LLMEnv._DEFAULT_STR_KEY] - assert sum(s == all_strings[0] for s in all_strings) == repeats - assert sum(s == all_strings[repeats] for s in all_strings) == repeats - assert sum(s == all_strings[repeats * 2] for s in all_strings) == repeats - else: - all_tokens = r_reset.view(-1)[LLMEnv._DEFAULT_TOKEN_KEY] - assert sum((s == all_tokens[0]).all() for s in all_tokens) == repeats - assert sum((s == all_tokens[repeats]).all() for s in all_tokens) == repeats - assert ( - sum((s == all_tokens[repeats * 2]).all() for s in all_tokens) == repeats - ) - - @pytest.mark.parametrize( - "from_text,stack_method", - [ - [True, None], - [False, "as_padded_tensor"], - ], - ) - @pytest.mark.parametrize("device", [None]) - @pytest.mark.parametrize("dl_batch_size", [1, 4]) - @pytest.mark.parametrize("env_batch_size", [None, 0, (), 4]) - @pytest.mark.parametrize("repeats", [3]) - @pytest.mark.parametrize( - "assign_reward,assign_done", [[True, False], [True, True], [False, True]] - ) - def test_done_and_reward( - self, - from_text, - stack_method, - device, - env_batch_size, - dl_batch_size, - repeats, - assign_reward, - assign_done, - ): - with pytest.raises( - ValueError, match="from_text" - ) if from_text else contextlib.nullcontext(): - if from_text: - kwargs = { - "dataloader": DummyStrDataLoader(batch_size=dl_batch_size), - "repeats": repeats, - "assign_reward": assign_reward, - "assign_done": assign_done, - } - else: - if stack_method is None: - stack_method = as_padded_tensor - kwargs = { - "dataloader": DummyTensorDataLoader( - 
padding=True, batch_size=dl_batch_size - ), - "stack_method": stack_method, - "repeats": repeats, - "assign_reward": assign_reward, - "assign_done": assign_done, - } - kwargs.update( - { - "batch_size": env_batch_size, - "from_text": from_text, - "device": device, - "has_attention": False, - } - ) - with pytest.warns(UserWarning, match="eos_token_id"): - env = LLMEnv.from_dataloader(**kwargs) - # We want to make sure that transforms that rely on the done state work appropriately - env.append_transform(StepCounter(max_steps=10)) - - def policy(td): - td[LLMEnv._DEFAULT_ACTION_TOKENS_KEY] = torch.ones( - td.shape + (torch.randint(10, (1,)).item(),), dtype=torch.int64 - ) - return td - - r = env.rollout(100, policy, break_when_any_done=False) - if assign_done: - assert "terminated" in r - assert "done" in r - - class TestAsyncEnvPool: def make_env(self, *, makers, backend): return AsyncEnvPool(makers, backend=backend) diff --git a/torchrl/data/llm/__init__.py b/torchrl/data/llm/__init__.py index 81a56a780db..601aadcf1b1 100644 --- a/torchrl/data/llm/__init__.py +++ b/torchrl/data/llm/__init__.py @@ -3,7 +3,6 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from .chat import History from .dataset import ( create_infinite_iterator, get_dataloader, @@ -12,21 +11,11 @@ ) from .prompt import PromptData, PromptTensorDictTokenizer from .reward import PairwiseDataset, RewardData -from .utils import ( - AdaptiveKLController, - ConstantKLController, - LLMData, - LLMInput, - LLMOutput, - RolloutFromModel, -) +from .utils import AdaptiveKLController, ConstantKLController, RolloutFromModel __all__ = [ "AdaptiveKLController", "ConstantKLController", - "LLMData", - "LLMInput", - "LLMOutput", "PairwiseDataset", "PromptData", "PromptTensorDictTokenizer", @@ -36,5 +25,4 @@ "TokenizedDatasetLoader", "create_infinite_iterator", "get_dataloader", - "History", ] diff --git a/torchrl/data/llm/chat.py b/torchrl/data/llm/chat.py deleted file mode 100644 index a4d6cfd09c5..00000000000 --- a/torchrl/data/llm/chat.py +++ /dev/null @@ -1,347 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. -from __future__ import annotations - -import dataclasses - -import re -from typing import Literal - -import torch - -from tensordict import lazy_stack, LazyStackedTensorDict, list_to_stack, TensorClass -from tensordict.utils import _maybe_correct_neg_dim -from torchrl._utils import logger as torchrl_logger - -_TEMPLATES = { - "chatml_format": """{% for message in messages %} - {{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}} -{% endfor %} -{%- if add_generation_prompt %} - {{- '<|im_start|>assistant\n' }} -{%- endif %} -""", -} - - -class History(TensorClass["nocast"]): - """A class representing a structured history of messages in a conversation, designed for efficient manipulation and integration with language models. - - The `History` class provides a centralized API for managing conversational data, offering several advantages over - traditional list-based approaches: - - - Centralized API for conversion to and from string formats, facilitating seamless integration with language models. - - Efficient methods to append, extend, and reshape history elements, enabling dynamic construction of conversation - trajectories, especially useful in reinforcement learning environments. 
- - Interoperability with the `transformers` API, allowing for easy tokenization and preparation of input data. - - Attributes: - role (str): The role of the message sender. - content (str): The content of the message. - - Methods: - apply_chat_template: converts the `History` object to str / tokens. - append: append one element to the list of items along a given dimension. - extend: extend the list of items along a given dimension. - - Examples: - >>> # With tensordict < 0.10, we need to tell the lib that lists constitute batches - >>> import tensordict - >>> tensordict.set_list_to_stack(True).set() - >>> import transformers - >>> history0 = History( - ... role='system', - ... content='''CONTENT - ... This is the setup''', - ... ) - >>> history1 = History( - ... role='user', - ... content='''CONTENT - ... This is the first user prompt''', - ... ) - >>> history2 = History( - ... role='assistant', - ... content='''CONTENT - ... This is the second prompt, the first for the assistant.''', - ... ) - >>> history = torch.stack([history0, history1, history2]) - >>> assert history.role == ['system', 'user', 'assistant'] - >>> tokenizer = transformers.AutoTokenizer.from_pretrained("GPT2") - >>> # Apply a template to pass the history to an LLM. Note that the output has - >>> # an additional prompt to elict an answer from the LLM thanks to the 'add_generation_prompt' argument. - >>> parsed_string = history.apply_chat_template(tokenizer=tokenizer, add_generation_prompt=True) - >>> parsed_string - <|im_start|>system - CONTENT - This is the setup<|im_end|> - - <|im_start|>user - CONTENT - This is the first user prompt<|im_end|> - - <|im_start|>assistant - CONTENT - This is the second prompt, the first for the assistant.<|im_end|> - - <|im_start|>assistant - - """ - - role: str - content: str - - def __post_init__(self): - if not list_to_stack(): - raise RuntimeError( - "Please set the list_to_stack to True using tensordict.set_list_to_stack(True).set() at the beginning of your script, " - "or the LIST_TO_STACK=1 environment variable." - ) - - def apply_chat_template( - self, - *, - tokenizer: transformers.PreTrainedTokenizer, # noqa - add_generation_prompt: bool = True, - chat_template: str = _TEMPLATES["chatml_format"], - continue_final_message: bool = False, - tokenize: bool = False, - padding: bool | str = False, - truncation: bool | str = False, - return_tensors: str | None = "pt", - **kwargs, - ): - """Applies a chat template to the history. - - Keyword Args: - tokenizer (transformers.PreTrainedTokenizer): The tokenizer to use. - add_generation_prompt (bool, optional): Whether to add a generation prompt. Defaults to True. - chat_template (str, optional): The chat template to use. Defaults to _TEMPLATES["chatml_format"]. - continue_final_message (bool, optional): Whether to continue the final message. Defaults to False. - tokenize (bool, optional): Whether to tokenize the output. Defaults to False. - padding (bool | str, optional): The padding strategy to use. Defaults to False. - truncation (bool | str, optional): The truncation strategy to use. Defaults to False. - return_tensors (str | None, optional): The type of tensors to return. Defaults to "pt". - **kwargs: Additional keyword arguments to pass to the tokenizer `apply_chat_template` method. - - Returns: - The formatted history. 
- """ - self_flat = self.view(-1).tolist() - return tokenizer.apply_chat_template( - self_flat, - add_generation_prompt=add_generation_prompt, - chat_template=chat_template, - tokenize=tokenize, - padding=padding, - truncation=truncation, - return_tensors=return_tensors, - continue_final_message=continue_final_message, - ) - - @classmethod - def inv_chat_template( - cls, text: str, chat_template_name: Literal["chatml_format"] = "chatml_format" - ) -> History: - if chat_template_name not in ("chatml_format",): - # Hard coded for now - raise NotImplementedError( - "chat_template_name must be one of ('chatml_format',)" - ) - if isinstance(text, list): - return torch.stack([cls._inv_chatml(text) for text in text]) - return cls._inv_chatml(text) - - @classmethod - def _inv_chatml(cls, text: str) -> History: - """Inverts a chatml string into a History object. - - Args: - text (str): The chatml string to invert. - - Returns: - History: The inverted History object. - """ - torchrl_logger.debug(f"Inverting chatml:\n{text}") - pattern = r"<\|im_start\|>(.*?)\n(.*?)<\|im_end\|>" - matches = re.findall(pattern, text, flags=re.DOTALL) - roles = [] - contents = [] - for match in matches: - role = match[0].strip() - - # Override role - # role = "assistant" - content = match[1].strip() - roles.append(role) - contents.append(content) - if not roles: - raise RuntimeError( - f"Couldn't get a single item out of text {text}. A common cause " - f"if that special tokens should not be ommitted, did you set include_stop_str_in_output/skip_special_tokens=False?" - ) - - return cls( - role=roles, - content=contents, - batch_size=len(roles), - ) - - def append( - self, history: History, *, inplace: bool = True, dim: int = 0 - ) -> History: - """Appends a new history to the current one. - - Args: - history (History): The new history to append. - inplace (bool, optional): Whether to perform the operation in-place. Defaults to True. - dim (int, optional): The dimension to append along. Defaults to 0. - - Returns: - History: The appended History object. - """ - if not self.batch_dims: - raise RuntimeError( - "Cannot append an element to a batchless History. Call unsqueeze(dim=0) first on self." - ) - if self.batch_dims != history.batch_dims + 1: - raise RuntimeError( - f"The new history to append must have one less dimension than self. Got self.ndim={self.ndim} and history.ndim={history.ndim}." - ) - if inplace: - dim = _maybe_correct_neg_dim(dim, self.batch_size) - if ( - isinstance(self._tensordict, LazyStackedTensorDict) - and self._tensordict.stack_dim == dim - ): - td = history._tensordict - if td.device != self.device: - if self.device is None: - td = td.copy().clear_device_() - else: - td = td.to(self.device) - self._tensordict.append(td) - return self - else: - td = history._tensordict - if td.device != self.device: - if self.device is None: - td = td.copy().clear_device_() - else: - td = td.to(self.device) - td = lazy_stack(list(self._tensordict.unbind(dim)) + [td], dim=dim) - self.__dict__["_tensordict"] = td - return self - if history.device != self.device: - if self.device is None: - history = history.copy().clear_device_() - else: - history = history.to(self.device) - return torch.stack(list(self.unbind(dim)) + [history], dim=dim) - - def extend( - self, history: History, *, inplace: bool = True, dim: int = 0 - ) -> History: - if not self.batch_dims: - raise RuntimeError( - "Cannot add an element to a batchless History. Call unsqueeze(dim=0) first on self." 
- ) - if self.batch_dims != history.batch_dims: - raise RuntimeError( - f"The new history to extend must have as many dimensions as self. Got self.ndim={self.ndim} and history.ndim={self.ndim}." - ) - if inplace: - dim = _maybe_correct_neg_dim(dim, self.batch_size) - if ( - isinstance(self._tensordict, LazyStackedTensorDict) - and self._tensordict.stack_dim == dim - ): - td = history._tensordict - if td.device != self.device: - if self.device is None: - td = td.copy().clear_device_() - else: - td = td.to(self.device) - self._tensordict.extend(td) - return self - else: - td = lazy_stack( - list(self._tensordict.unbind(dim)) - + list(history._tensordict.unbind(dim)), - dim=dim, - ) - if td.device != self.device: - if self.device is None: - td = td.copy().clear_device_() - else: - td = td.to(self.device) - self.__dict__["_tensordict"] = td - return self - if history.device != self.device: - if self.device is None: - history = history.copy().clear_device_() - else: - history = history.to(self.device) - return torch.stack(list(self.unbind(dim)) + list(history.unbind(dim)), dim=dim) - - @classmethod - def default_spec(cls, shape=(-1,)): - """A default spec to use in transforms / envs that return History objects. - - Args: - shape (torch.Size, optional): The shape of the returned History spec. Defaults to `(-1)` (variable length - along time dimension). - - Example: - >>> import tensordict - >>> from torchrl.data import History - >>> tensordict.set_list_to_stack(True).set() - >>> - >>> history = History(role=["system", "user"], content=["a message", "another message"], batch_size=(2,)) - >>> spec = history.default_spec() - >>> print(spec) - Composite( - role: NonTensor( - shape=torch.Size([-1]), - space=None, - device=None, - dtype=None, - domain=None, - example_data=foo), - content: NonTensor( - shape=torch.Size([-1]), - space=None, - device=None, - dtype=None, - domain=None, - example_data=foo), - device=None, - shape=torch.Size([-1])) - >>> print(spec.zero()) - History( - content=NonTensorData(data=foo, batch_size=torch.Size([1]), device=None), - role=NonTensorData(data=foo, batch_size=torch.Size([1]), device=None), - batch_size=torch.Size([1]), - device=None, - is_shared=False) - - """ - from torchrl.data import Composite, NonTensor - - def get_default_value(field): - if field.default is not dataclasses.MISSING: - return field.default - elif field.type in (str, "str"): - return "foo" - else: - return None - - defaults = { - k: NonTensor( - example_data=get_default_value(cls.__dataclass_fields__[k]), shape=(-1,) - ) - for k in cls.__dataclass_fields__ - } - - return Composite(defaults, shape=shape, data_cls=cls) diff --git a/torchrl/data/llm/utils.py b/torchrl/data/llm/utils.py index 6ffa3641467..9bb71bb02f3 100644 --- a/torchrl/data/llm/utils.py +++ b/torchrl/data/llm/utils.py @@ -7,11 +7,10 @@ import abc import collections import importlib -from typing import TypeVar import numpy as np import torch -from tensordict import TensorClass, TensorDict +from tensordict import TensorDict from torch import nn, Tensor from torch.nn import functional as F @@ -542,96 +541,3 @@ def step_scheduler(self): # remove all values while len(self._kl_queue): self._kl_queue.remove(self._kl_queue[0]) - - -LLMInpOut = TypeVar("LLMInpOut") - - -class LLMInput(TensorClass["nocast"]): - """Represents the input to a Large Language Model (LLM). - - Attributes: - tokens (torch.Tensor): The input tokens as a tensor. - attention_mask (torch.Tensor, optional): The attention mask for the input tokens. Default to `None`. 
- token_list (list[int] | list[list[int]], optional): The input tokens as a list of integers or a list of lists of integers. Default to `None`. - text (str | list[str], optional): The input text as a string or a list of strings. Default to `None`. - - .. seealso:: :class:`~torchrl.data.LLMOutput` and :class:`~torchrl.data.LLMData`. - - """ - - tokens: torch.Tensor - attention_mask: torch.Tensor | None = None - token_list: list[int] | list[list[int]] | None = None - text: str | list[str] | None = None - - -class LLMOutput(TensorClass["nocast"]): - """Represents the output from a Large Language Model (LLM). - - Attributes: - tokens (torch.Tensor): The output tokens as a tensor. - tokens_response (torch.Tensor, optional): The response tokens generated by the model. Default to `None`. - - .. note:: the reponse is the sequence of tokens output by a model, excluding the input - tokens. - - token_list (list[int] | list[list[int]], optional): The output tokens as a list of integers or a list of lists of integers. Default to `None`. - tokens_response_list (list[list[int]], optional): The response tokens generated by the model as a list of lists of integers. Default to `None`. - logits (torch.Tensor, optional): The logits of the output tokens. Default to `None`. - log_probs (torch.Tensor, optional): The log probabilities of the output tokens. Default to `None`. - text (str | list[str], optional): The output text as a string or a list of strings. Default to `None`. - - .. seealso:: :class:`~torchrl.data.LLMInput` and :class:`~torchrl.data.LLMData`. - - """ - - tokens: torch.Tensor - tokens_response: torch.Tensor | None = None - token_list: list[int] | list[list[int]] | None = None - tokens_response_list: list[list[int]] | None = None - logits: torch.Tensor | None = None - log_probs: torch.Tensor | None = None - text: str | list[str] | None = None - - @classmethod - def from_vllm_output(cls: type[LLMInpOut], vllm_output) -> LLMInpOut: - # placeholder - raise NotImplementedError - - -class LLMData(TensorClass["nocast"]): - """Represents the input or output of a Large Language Model (LLM). - - Other algorithm-specific attributes such as `reward`, `advantages` or done states are handled automatically by the - envs and, therefore, are not included in this class. - - Attributes: - tokens (torch.Tensor): The input/output tokens as a tensor. - attention_mask (torch.Tensor, optional): The attention mask for the input tokens. Default to `None`. - tokens_response (torch.Tensor, optional): The response tokens generated by the model. Default to `None`. - - .. note:: the reponse is the sequence of tokens output by a model, excluding the input - tokens. - - token_list (list[int] | list[list[int]], optional): The output tokens as a list of integers or a list of lists - of integers. Default to `None`. - tokens_response_list (list[list[int]], optional): The response tokens generated by the model as a list of - lists of integers. Default to `None`. - logits (torch.Tensor, optional): The logits of the output tokens. Default to `None`. - log_probs (torch.Tensor, optional): The log probabilities of the output tokens. Default to `None`. - text (str | list[str], optional): The output text as a string or a list of strings. Default to `None`. - - .. seealso:: :class:`~torchrl.data.LLMInput` and :class:`~torchrl.data.LLMOutput`. 
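For reference, a minimal sketch of the kind of tensorclass container these removed classes provided, following the attribute lists above. `MiniLLMData` and its reduced field set are illustrative stand-ins only; the sketch assumes nothing beyond the `TensorClass["nocast"]` base that the original definitions use.

from __future__ import annotations

import torch
from tensordict import TensorClass

class MiniLLMData(TensorClass["nocast"]):
    # Trimmed-down stand-in mirroring a subset of the removed LLMData fields.
    tokens: torch.Tensor | None = None
    tokens_response: torch.Tensor | None = None
    attention_mask: torch.Tensor | None = None
    text: str | list[str] | None = None

data = MiniLLMData(
    tokens=torch.randint(0, 1024, (2, 10)),
    attention_mask=torch.ones(2, 10, dtype=torch.int64),
    batch_size=(2,),
)
print(data.tokens.shape)  # torch.Size([2, 10]); unset fields stay None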
- - """ - - tokens: torch.Tensor | None = None - tokens_response: torch.Tensor | None = None - attention_mask: torch.Tensor | None = None - token_list: list[int] | list[list[int]] | None = None - tokens_response_list: list[list[int]] | None = None - logits: torch.Tensor | None = None - log_probs: torch.Tensor | None = None - text: str | list[str] | None = None - text_response: torch.Tensor | None = None diff --git a/torchrl/envs/custom/llm.py b/torchrl/envs/custom/llm.py index 6dec617858b..4d323e9e577 100644 --- a/torchrl/envs/custom/llm.py +++ b/torchrl/envs/custom/llm.py @@ -4,29 +4,15 @@ # LICENSE file in the root directory of this source tree. from __future__ import annotations -import warnings - -from typing import Any, Callable, Literal +from typing import Callable import torch -from tensordict import ( - is_leaf_nontensor, - LazyStackedTensorDict, - NestedKey, - set_list_to_stack, - TensorDict, - TensorDictBase, - unravel_key, -) +from tensordict import NestedKey, set_list_to_stack, TensorDict, TensorDictBase from tensordict.tensorclass import NonTensorData, NonTensorStack -from tensordict.utils import _zip_strict -from torch.utils.data import DataLoader -from torchrl._utils import _replace_last from torchrl.data.map.hash import SipHash from torchrl.data.tensor_specs import ( - Bounded, Categorical as CategoricalSpec, Composite, NonTensor, @@ -34,551 +20,6 @@ ) from torchrl.envs import EnvBase from torchrl.envs.utils import _StepMDP -from torchrl.modules.utils.utils import _unpad_tensors - - -class LLMEnv(EnvBase): - """A text generation environment for language models. - - This environment is designed to work with language models, where the observation is a string or a tensor of - integers representing a sequence of tokens. The action is also a string or a tensor of integers, which is - concatenated to the previous observation to form the new observation. - - By default, this environment is meant to track history for a prompt. Users can append transforms to tailor - this to their use case, such as Chain of Thought (CoT) reasoning or other custom processing. - - Users must append a transform to set the "done" condition, which would trigger the loading of the next prompt. - Prompts to the language model can be loaded when the environment is ``reset`` if the environment is created via - :meth:`~from_dataloader`. - - .. note:: The default arguments of the `LLMEnv` class are set to make it easy to run this environment with - the vllm backend (:class:`~torchrl.modules.vLLMWrapper`). - - Keyword Args: - token_key (NestedKey, optional): The key in the tensordict where the tokens are stored (when `from_text=False`). - Defaults to ``"tokens"``. - str_key (NestedKey, optional): The key in the tensordict where the string input is stored (when `from_text=True`). - Defaults to ``"text"``. - attention_key (NestedKey, optional): The key in the tensordict where the attention mask is stored. - Defaults to ``"attention_mask"``. - action_key (NestedKey, optional): The key in the tensordict where the action is stored. Defaults to - ``"tokens_response"`` or ``"text_response"``. - reward_key (NestedKey, optional): The key in the tensordict where the reward is stored if `assign_reward=True`. - Defaults to ``"reward"``. - from_text (bool, optional): Whether the environment should expect strings as input and output. Defaults to ``True``. - device (torch.device | None, optional): The device on which the environment should run. Defaults to ``None``. 
- vocab_size (int | None, optional): The size of the vocabulary. If None, the environment will assume an - unbounded vocabulary. Defaults to ``None``. - has_attention (bool, optional): If ``True``, an attention mask is to be used under the key indicated by - :attr:`attention_key`. Defaults to ``True``. - assign_reward (bool, optional): If ``True``, a zero-valued reward of shape equal to the action shape - is written during calls to `step()`. Defaults to ``False``. - assign_done (bool, optional): If ``True``, a zero-valued done and terminated state of shape equal to the - action shape is written during calls to `step()`. Defaults to ``False``. - .. note:: Regardless of the value assigned to `assign_done`, a done state will be written at the root - as it is a requirement for all TorchRL environments. - batch_size (int or torch.Size, optional): Batch size of the environment. - If left empty, an empty batch-size is assumed. - The batch size can be null (`torch.Size([])`) or one-dimensional. Batchless environments are not supported. - - .. note:: When using a :class:`~torchrl.envs.DataLoadingPrimer` transform, the batch-size of the env - and the transform should match. - - eos_token_id (int, optional): The token id of the end of the sequence. If passed, the `done` state - is set to `True` when detected. Defaults to `None`. - - .. seealso:: :class:`~torchrl.envs.DataLoadingPrimer` for examples. - - Methods: - from_dataloader: Creates an LLMEnv instance from a dataloader. - - """ - - _DEFAULT_TOKEN_KEY = "tokens" - _DEFAULT_STR_KEY = "text" - _DEFAULT_ATTENTION_KEY = "attention_mask" - _DEFAULT_ACTION_TOKENS_KEY = "tokens_response" - _DEFAULT_ACTION_STR_KEY = "text_response" - - def __init__( - self, - *, - token_key: NestedKey | None = None, - str_key: NestedKey | None = None, - attention_key: NestedKey | None = None, - action_key: NestedKey | None = None, - reward_key: NestedKey = "reward", - from_text: bool = True, - device: torch.device | None = None, - vocab_size: int | None = None, - assign_reward: bool = False, - assign_done: bool = False, - batch_size: int | torch.Size | None = None, - has_attention: bool = True, - # Experimental - as_llm_data: bool = False, - eos_token_id: int | None = None, - ) -> None: - self.as_llm_data = as_llm_data - if token_key is None: - token_key = self._DEFAULT_TOKEN_KEY - if str_key is None: - str_key = self._DEFAULT_STR_KEY - if attention_key is None: - attention_key = self._DEFAULT_ATTENTION_KEY - if action_key is None: - if from_text: - action_key = self._DEFAULT_ACTION_STR_KEY - else: - action_key = self._DEFAULT_ACTION_TOKENS_KEY - self._batch_locked = True - if batch_size is None: - batch_size = () - else: - if not isinstance(batch_size, (tuple, list)): - batch_size = (batch_size,) - elif len(batch_size) > 1: - raise TypeError( - f"batch-size of LLMEnv must be 0 or 1d. Got batch_size={batch_size}." - ) - super().__init__( - device=device, - batch_size=batch_size, - ) - self.has_attention = has_attention - self.from_text = from_text - self.vocab_size = vocab_size - self.token_key = unravel_key(token_key) - self.str_key = unravel_key(str_key) - if attention_key is not None: - attention_key = unravel_key(attention_key) - self.attention_key = attention_key - self.assign_reward = assign_reward - self.assign_done = assign_done - self.eos_token_id = eos_token_id - if eos_token_id is None: - warnings.warn( - "eos_token_id is missing. This means that the environment will not be able to capture its " - "done state automatically. 
This may lead to undefined behaviors when the generated text reaches " - "an eos_token.", - category=UserWarning, - ) - - # self.action_key = unravel_key(action_key) - if from_text: - self.full_observation_spec_unbatched = Composite( - { - self.str_key: NonTensor( - example_data="a string", - batched=True, - shape=(), - device=device, - ) - } - ) - self.full_action_spec_unbatched = Composite( - { - action_key: NonTensor( - example_data="a string", batched=True, shape=(), device=device - ) - } - ) - else: - if vocab_size is None: - observation_spec = { - token_key: Unbounded(shape=(-1,), dtype=torch.int64, device=device) - } - if self.has_attention: - observation_spec[attention_key] = Unbounded( - shape=(-1,), dtype=torch.int64, device=device - ) - self.full_observation_spec_unbatched = Composite(observation_spec) - self.full_action_spec_unbatched = Composite( - { - action_key: Unbounded( - shape=(-1,), dtype=torch.int64, device=device - ) - } - ) - else: - self.full_observation_spec_unbatched = Composite( - { - token_key: Bounded( - shape=(-1,), - dtype=torch.int64, - low=0, - high=vocab_size, - device=device, - ) - } - ) - self.full_action_spec_unbatched = Composite( - { - action_key: Bounded( - shape=(-1,), - dtype=torch.int64, - low=0, - high=vocab_size, - device=device, - ) - } - ) - STR2STR_ERR = ValueError( - "from_text cannot be True when either of assign_reward / assign_done are True. " - "Tokens are required to compute the reward shape." - ) - if self.assign_reward: - if self.from_text: - raise STR2STR_ERR - self.full_reward_spec_unbatched = Composite( - {reward_key: Unbounded(shape=(-1,), device=device)} - ) - else: - self.full_reward_spec_unbatched = Composite(device=device) - - if not self.assign_done: - # Use single done - self.full_done_spec_unbatched = Composite( - done=Unbounded(shape=(1,), dtype=torch.bool, device=device), - terminated=Unbounded(shape=(1,), dtype=torch.bool, device=device), - ) - elif self.from_text: - raise STR2STR_ERR - else: - # Use single done - self.full_done_spec_unbatched = Composite( - tokens_data=Composite( - done=Unbounded(shape=(-1,), dtype=torch.bool, device=device), - terminated=Unbounded(shape=(-1,), dtype=torch.bool, device=device), - ), - done=Unbounded(shape=(1,), dtype=torch.bool, device=device), - terminated=Unbounded(shape=(1,), dtype=torch.bool, device=device), - ) - - @classmethod - def from_dataloader( - cls, - dataloader: DataLoader, - *, - tokenizer: transformers.PretrainedTokenizerBase | None = None, # noqa - token_key: NestedKey | None = None, - str_key: NestedKey | None = None, - attention_key: NestedKey | None = None, - action_key: NestedKey | None = None, - reward_key: NestedKey = "reward", - from_text: bool = True, - device: torch.device | None = None, - vocab_size: int | None = None, - batch_size: int | torch.Size | None = None, - has_attention: bool = True, - assign_reward: bool = False, - assign_done: bool = False, - primers: Composite | None = None, - example_data: Any = None, - stack_method: Callable[[Any], Any] - | Literal["as_nested_tensor", "as_padded_tensor"] = None, - repeats: int | None = None, - group_repeats: bool = True, - eos_token_id: int | None = None, - ) -> LLMEnv: - """Creates an LLMEnv instance from a dataloader. - - This method creates an LLMEnv instance and appends a DataLoadingPrimer to it, which populates ``data_keys`` (by default ``observation_key``) with data from the provided dataloader when the environment is reset. - - Args: - dataloader (DataLoader): The dataloader to load data from. 
- - Keyword Args: - tokenizer (transformers.PretrainedTokenizerBase or str, optional): the tokenizer to use. If ``None``, - "bert-base-uncased" will be used by default. If a string is provided, it should be the name of a - pre-trained tokenizer. - - .. note:: Using the `tokenizer` will append a :class:`~torchrl.envs.Tokenizer` transform to the environment. - If `from_text` is set to `True`, the tokenizer will be called during every iteration and the rollout - will contain both tokens and text data. - If `from_text` is set to `False`, the tokenizer will be called during reset only, and the only - text data in the rollout will be the text sampled from the dataset. - - token_key (NestedKey, optional): The key in the tensordict where the tokens are stored (when `from_text=False`). - Defaults to ``("tokens_in", "input_ids")``. - str_key (NestedKey, optional): The key in the tensordict where the string input is stored (when `from_text=True`). - Defaults to ``"test"``. - attention_key (NestedKey, optional): The key in the tensordict where the attention mask is stored. - Defaults to ``("tokens_in", "input_ids")`` - action_key (NestedKey, optional): The key in the tensordict where the action is stored. Defaults to - ``("tokens_out", "sequences")``. - reward_key (NestedKey, optional): The key in the tensordict where the reward is stored if `assign_reward=True`. - Defaults to ``"reward"``. - from_text (bool, optional): Whether the environment should expect strings as input and output. Defaults to ``True``. - device (torch.device | None, optional): The device on which the environment should run. Defaults to ``None``. - vocab_size (int | None, optional): The size of the vocabulary. If None, the environment will assume an - unbounded vocabulary. Defaults to ``None``. - has_attention (bool, optional): if ``True``, an attention mask is to be used under the key indicated by - :attr:`attention_key`. Defaults to ``True``. - assign_reward (bool, optional): if ``True``, a zero-valued reward of shape equal to to the action shape - is written during calls to `step()`. Defaults to ``False``. - assign_done (bool, optional): if ``True``, a zero-valued done and terminated state of shape equal to to the - action shape is written during calls to `step()`. Defaults to ``False``. - - .. note:: regardless of the value assigned to `assign_done`, a done state will be written at the root - as it is a requirement for all TorchRL environments. - - batch_size (int or torch.Size, optional): Batch size of the environment. - If left empty, the batch size is inferred from `dataloader.batch_size` if that attribute exists, otherwise - it is set to `()`. - The batch size can be null (`torch.Size([])`) or one-dimensional. Batchless environments are not supported. - - .. note:: When using a :class:`~torchrl.envs.DataLoadingPrimer` transform, the batch-size of the env - and the transform should match. - - primers (Composite | None, optional): The primers to use for each key in the dataloader. - Defaults to ``None`` (inferred automatically from the first batch of data). - stack_method (Callable[[Any], Any] | Literal["as_nested_tensor", "as_padded_tensor"], optional): The - method to use for stacking the data. Defaults to ``None``. - repeats (int, optional): How many times the same sample needs to appear successively. This can be useful in - situations like GRPO where a single prompt is used multiple times to estimate the advantage using Monte-Carlo - samples (rather than an advantage module). 
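As context for the `from_dataloader` constructor documented here, a hedged sketch of a typical call before this removal. `StringDataLoader` is a made-up stand-in that yields mappings with a "text" field, and the list-to-stack setting mirrors the doctest convention used elsewhere in this file; exact key handling follows the defaults listed above.

from tensordict import set_list_to_stack
from torchrl.envs import LLMEnv  # available prior to this patch

set_list_to_stack(True).set()

class StringDataLoader:
    # Made-up stand-in: yields mappings whose "text" entry is a batch of prompts.
    batch_size = 4

    def __iter__(self):
        while True:
            yield {"text": ["a prompt"] * self.batch_size}

env = LLMEnv.from_dataloader(
    dataloader=StringDataLoader(),
    from_text=True,
    batch_size=4,
    repeats=2,
    group_repeats=True,  # the 2 repeats are grouped, so resets deliver batches of 8
)
td = env.reset()  # prompts are pulled from the dataloader by the appended DataLoadingPrimer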
- group_repeats (bool, optional): if ``True``, the batch-size is multiplied by the number of repeats such that - all repeats are grouped in a single batch collected from the buffer. Defaults to ``True``. - eos_token_id (int, optional): The token id of the end of the sequence. If passed, the `done` state - is set to `True` when detected. Defaults to `None`. - - Returns: - LLMEnv: The created LLMEnv instance. - """ - from torchrl.envs import DataLoadingPrimer, Tokenizer - - if str_key is None: - str_key = LLMEnv._DEFAULT_STR_KEY - if token_key is None: - token_key = LLMEnv._DEFAULT_TOKEN_KEY - if attention_key is None: - attention_key = LLMEnv._DEFAULT_ATTENTION_KEY - elif tokenizer is not None and attention_key != _replace_last( - token_key, "attention_mask" - ): - raise ValueError( - "When using the Tokenizer, attention key must match `(*token_key[:-1], 'attention_mask')` where " - f"`token_key` is a tuple-typed nested key. Got attention_key={attention_key} while expecting " - f"{_replace_last(token_key, 'attention_mask')}." - ) - - if tokenizer is not None: - if from_text: - # In this case, the tokenizer is appended to the env after each step - if action_key is None: - action_key = cls._DEFAULT_ACTION_STR_KEY - tokenizer_transform = Tokenizer( - tokenizer=tokenizer, - in_keys=[str_key], - out_keys=[token_key], - # Assume that the tokens are named according to _DEFAULT_ACTION_TOKENS_KEY - in_keys_inv=[action_key], - out_keys_inv=[cls._DEFAULT_ACTION_TOKENS_KEY], - call_before_reset=False, - # We should always see the required entries - missing_tolerance=False, - ) - else: - # FIXME: This is broken - do we need it anyway? - raise RuntimeError( - "tokenizers can only be used whenever from_text is set to `True`." - ) - - primer = DataLoadingPrimer( - dataloader=dataloader, - primers=primers, - stack_method=stack_method, - repeats=repeats, - device=device, - group_repeats=group_repeats, - batch_size=batch_size, - ) - env = LLMEnv( - from_text=from_text, - device=device, - token_key=token_key, - str_key=str_key, - attention_key=attention_key, - action_key=action_key, - reward_key=reward_key, - vocab_size=vocab_size, - assign_reward=assign_reward, - assign_done=assign_done, - batch_size=primer.batch_size, - has_attention=has_attention, - eos_token_id=eos_token_id, - ) - if tokenizer is not None: - env = env.append_transform(tokenizer_transform) - return env.append_transform(primer) - - @staticmethod - def _check_obs_act_and_cat(obs, action, *, device): - if not isinstance(obs, str): - raise TypeError(f"Observation must be a string, got {type(obs)}.") - if not isinstance(action, str): - raise TypeError(f"Action must be a string, got {type(action)}.") - return NonTensorData(obs + action, device=device) - - def _step( - self, - tensordict: TensorDictBase, - ) -> TensorDictBase: - next_td = tensordict.empty() - self._make_next_obs(tensordict, next_td) - self._maybe_make_reward(tensordict, next_td) - self._maybe_make_done(tensordict, next_td) - if self.as_llm_data: - raise NotImplementedError() - return next_td - - def _maybe_make_reward( - self, tensordict: TensorDictBase, next_td: TensorDictBase - ) -> TensorDictBase: - if self.assign_reward: - next_td.set( - self.reward_key, - torch.zeros_like( - tensordict.get(self.action_key), dtype=self.reward_spec.dtype - ), - ) - return next_td - - def _maybe_make_done( - self, - tensordict: TensorDictBase, - next_td: TensorDictBase, - resetting: bool = False, - ) -> TensorDictBase: - if self.assign_done: - action = tensordict.get(self.action_key) - if action 
is None: - done = torch.zeros( - tensordict.shape + (1,), dtype=torch.bool, device=self.device - ) - else: - done = torch.zeros_like(action, dtype=torch.bool) - next_td.set(("tokens_data", "terminated"), done) - next_td.set(("tokens_data", "done"), done.clone()) - next_td.set( - "done", next_td.get(("tokens_data", "done")).any(-1, keepdim=True) - ) - next_td.set( - "terminated", - next_td.get(("tokens_data", "terminated")).any(-1, keepdim=True), - ) - if not resetting and self.eos_token_id is not None: - if self.from_text: - token_action_key = self._DEFAULT_ACTION_TOKENS_KEY - else: - token_action_key = self.action_key - action = tensordict.get( - token_action_key, as_padded_tensor=True, padding_value=-1 - ) - mask = action == -1 - - if action is None: - raise RuntimeError( - f"Couldn't find the tokenized action with key {token_action_key} to set the done state in tensordict " - f"with keys {list(tensordict.keys(True))}." - ) - full_done = action == self.eos_token_id - done = full_done.any(-1, keepdim=True) - next_td.set("done", done) - next_td.set("terminated", done) - if self.assign_done: - full_done = _unpad_tensors(full_done, mask) - next_td.set(("tokens_data", "terminated"), full_done) - next_td.set(("tokens_data", "done"), full_done) - return next_td - - def _make_next_obs( - self, tensordict: TensorDictBase, nex_td: TensorDictBase - ) -> TensorDictBase: - # Cat action entry with prev obs - if self.from_text: - obs = tensordict[self.str_key] - action = tensordict[self.action_key] - if not tensordict.batch_size: - if not isinstance(obs, str) or not isinstance(action, str): - raise TypeError( - "The tensordict is batchless, yet the action and/or observations are not " - f"strings but {type(action)} and {type(obs)}, respectivly." - ) - observation = self._check_obs_act_and_cat( - obs, action, device=self.device - ) - else: - observation = NonTensorStack( - *[ - self._check_obs_act_and_cat(_obs, _action, device=self.device) - for (_obs, _action) in _zip_strict(obs, action) - ] - ) - return nex_td.set(self.str_key, observation) - else: - try: - obs: torch.Tensor = tensordict.get(self.token_key) - action = tensordict.get(self.action_key) - if getattr(obs, "is_nested", False): - observation = torch.nested.as_nested_tensor( - [ - torch.cat([_obs, _action], -1) - for _obs, _action in _zip_strict( - obs.unbind(0), action.unbind(0) - ) - ], - layout=obs.layout, - ) - else: - observation = torch.cat([obs, action], -1) - if self.has_attention: - attention_mask = tensordict.get(self.attention_key) - attention_mask = torch.cat( - [attention_mask, attention_mask.new_ones(action.shape)], -1 - ) - nex_td.set(self.attention_key, attention_mask) - except TypeError: - raise TypeError( - "Failed to cat action and observation tensors. Check that from_text argument is correctly " - f"set in {type(self).__name__}." - ) - return nex_td.set(self.token_key, observation) - - def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase: - # We should have an observation by this time, if not raise an exception - def check_token(): - return not self.from_text and ( - self.token_key not in tensordict.keys(isinstance(self.token_key, tuple)) - ) - - def check_str(): - return self.from_text and ( - self.str_key not in tensordict.keys(isinstance(self.str_key, tuple)) - ) - - if tensordict is None or check_token() or check_str(): - raise KeyError( - f"Observation key {self.token_key}/{self.str_key} is not defined in tensordict with keys " - f"{list(tensordict.keys(True, True, is_leaf=is_leaf_nontensor))}. 
Make sure a TensorDictPrimer (eg, " - f"torchrl.envs.DataLoadingPrimer) is appended to the env transforms." - ) - if not isinstance(tensordict, LazyStackedTensorDict) and tensordict.ndim: - tensordict = LazyStackedTensorDict(*tensordict.unbind(0)) - td_reset = tensordict.copy() - if td_reset.device != self.device: - if self.device is None: - td_reset.clear_device_() - else: - td_reset = td_reset.to(self.device) - tensordict = self._maybe_make_done(tensordict, td_reset, resetting=True) - if self.as_llm_data: - raise NotImplementedError() - return tensordict - - def _set_seed(self, seed: int | None) -> None: - return seed class LLMHashingEnv(EnvBase): diff --git a/torchrl/envs/transforms/llm.py b/torchrl/envs/transforms/llm.py index e7fc5db94d4..13e421f32fd 100644 --- a/torchrl/envs/transforms/llm.py +++ b/torchrl/envs/transforms/llm.py @@ -4,521 +4,18 @@ # LICENSE file in the root directory of this source tree. from __future__ import annotations -import warnings -from collections import deque -from collections.abc import Mapping from copy import copy, deepcopy -from typing import Any, Callable, Iterable, Literal import torch -from tensordict import lazy_stack, NestedKey, TensorDict, TensorDictBase, unravel_key +from tensordict import NestedKey, TensorDict, TensorDictBase, unravel_key from tensordict.nn import ProbabilisticTensorDictModule, TensorDictParams from tensordict.utils import is_seq_of_nested_key from torch import nn from torchrl.data.tensor_specs import Composite, Unbounded from torchrl.envs.common import EnvBase -from torchrl.envs.transforms.transforms import TensorDictPrimer, Transform +from torchrl.envs.transforms.transforms import Transform from torchrl.envs.transforms.utils import _set_missing_tolerance, _stateless_param -from torchrl.envs.utils import make_composite_from_td - - -def as_nested_tensor(list_of_tensordicts: list[TensorDictBase]) -> TensorDictBase: - """Stacks a list of tensordicts into a single tensordict with nested tensors. - - Args: - list_of_tensordicts (list[TensorDictBase]): A list of tensordicts to stack. - - Returns: - TensorDictBase: A tensordict with nested tensors. - - """ - - def _as_nested_tensor(*list_of_tensors): - return torch.nested.as_nested_tensor(list_of_tensors, layout=torch.jagged) - - batch_size = list(list_of_tensordicts[0].shape) - batch_size.insert(0, len(list_of_tensordicts)) - return list_of_tensordicts[0].apply( - _as_nested_tensor, *list_of_tensordicts[1:], batch_size=batch_size - ) - - -def as_padded_tensor( - list_of_tensordicts: list[[TensorDictBase]], dim=0, stack_dim: int = 0 -) -> TensorDictBase: - """Stacks a list of tensordicts into a single tensordict with padded tensors. - - Args: - list_of_tensordicts (list[[TensorDictBase]]): A list of tensordicts to stack. - dim (int, optional): The dimension along which to pad. Defaults to 0. - stack_dim (int, optional): The dimension along which to stack. Defaults to 0. - - Returns: - TensorDictBase: A tensordict with padded tensors. 
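As a quick, self-contained illustration of the left-padding scheme this helper implements (its body continues just below): zeros are prepended so that every tensor reaches the longest length before stacking.

import torch

seqs = [torch.arange(3), torch.arange(5)]
max_len = max(t.size(0) for t in seqs)
# Prepend zeros up to the longest sequence, then stack along a new leading dim.
padded = torch.stack(
    [torch.cat([t.new_zeros(max_len - t.size(0)), t]) for t in seqs]
)
# padded -> tensor([[0, 0, 0, 1, 2],
#                   [0, 1, 2, 3, 4]])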
- """ - - def _stack_tensors(*list_of_tensors): - if dim < 0: - raise ValueError("dim must be >= 0") - max_length = max([t.size(dim) for t in list_of_tensors]) - - def pad_tensor(tensor): - padding_length = max_length - tensor.size(dim) - shape = [ - s if i != dim else padding_length for i, s in enumerate(tensor.shape) - ] - return torch.cat((tensor.new_zeros(shape), tensor), dim=dim) - - return torch.stack([pad_tensor(t) for t in list_of_tensors], dim=stack_dim) - - batch_size = list(list_of_tensordicts[0].shape) - batch_size.insert(dim, len(list_of_tensordicts)) - result = list_of_tensordicts[0].apply( - _stack_tensors, *list_of_tensordicts[1:], batch_size=batch_size - ) - return result - - -class DataLoadingPrimer(TensorDictPrimer): - """A primer that loads data from a dataloader and converts it into a tensordict using ``stack_method``. - - Args: - dataloader (Iterable[Dict[str, Any]]): The dataloader to load data from. - During collection, we will attempt to convert it into a tensordict using :func:`~tensordict.from_dict` or a - similar function. - It is assumed that the elements retrieved from the dataloader come in batches along the first dimension - of every tensor, unless `dataloader.batch_size=0`. - The dataloader must yield mappable data structures (e.g., dictionaries). - - Keyword Args: - primers (Composite | None, optional): The primers to use for each key in the dataloader. Defaults to None. - stack_method (Callable[[Any], Any] | Literal["as_nested_tensor", "as_padded_tensor"], optional): The method to - use for stacking the data. Defaults to ``maybe_dense_stack``. - repeats (int, optional): How many times the same sample needs to appear successively. This can be useful in - situations like GRPO where a single prompt is used multiple times to estimate the advantage using Monte-Carlo - samples (rather than an advantage module). - batch_size (int, torch.Size or None): the batch-size of the data delivered by the transform. - This is somewhat unrelated to the batch-size of the dataloader, in the sense that this number may or may - not match the DL's batch size. - If left empty, the batch-size is inferred from `dataloader.batch_size` if that attribute exists. If not, - an empty batch-size will be used (`torch.Size([])`). - - .. note:: The batch-size of the Primer must match the batch-size of the parent environment (typically a - wrapper around :class:`~torchrl.envs.LLMEnv`). - - group_repeats (bool, optional): if ``True``, the batch-size is multiplied by the number of repeats such that - all repeats are grouped in a single batch collected from the buffer. Defaults to ``False``. - - Attributes: - dataloader (Iterable[Any]): The dataloader to load data from. - endless_dataloader (Iterable[Any]): An endless iterator over the dataloader. - stack_method (Callable[[Any], Any]): The method to use for stacking the data. - - .. seealso:: :class:`~torchrl.envs.LLMEnv` and :class:`~torchrl.envs.LLMEnv.from_dataloader`. - - Example of a dataloader yielding strings: - >>> import random - >>> import string - >>> import tensordict as td - >>> import torch - >>> from tensordict import TensorDict - >>> from torchrl.data import Unbounded - >>> from torchrl.envs import DataLoadingPrimer, LLMEnv - >>> td.set_capture_non_tensor_stack(False).set() - >>> class DummyDataLoader: - ... '''A dummy dataloader that generates random strings.''' - ... def __init__(self, batch_size: int = 0): - ... self.batch_size = batch_size - ... def generate_random_string(self, length: int = 10) -. str: - ... 
'''Generate a random string of a given length.''' - ... return ''.join(random.choice(string.ascii_lowercase) for _ in range(length)) - ... def __iter__(self): - ... return self - ... def __next__(self): - ... if self.batch_size == 0: - ... return self.generate_random_string() - ... else: - ... return [self.generate_random_string() for _ in range(self.batch_size)] - >>> # Create an LLM environment with string-to-string input/output. - >>> env = LLMEnv(from_text=True) - >>> # Append a DataLoadingPrimer to the environment. - >>> env = env.append_transform( - >>> DataLoadingPrimer( - >>> dataloader=DummyDataLoader(), - >>> example_data="a string!", - >>> ) - >>> ) - >>> # Test the environment. - >>> print(env.rand_action(TensorDict())) - TensorDict( - fields={ - action: NonTensorData(data=a string, batch_size=torch.Size([]), device=None)}, - batch_size=torch.Size([]), - device=None, - is_shared=False) - >>> print(env.rollout(3)) - TensorDict( - fields={ - action: NonTensorStack( - ['a string', 'a string', 'a string'], - batch_size=torch.Size([3]), - device=None), - done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), - next: TensorDict( - fields={ - done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), - observation: NonTensorStack( - ['zxwvupirska string', 'zxwvupirska stringa string..., - batch_size=torch.Size([3]), - device=None), - terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), - truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, - batch_size=torch.Size([3]), - device=None, - is_shared=False), - observation: NonTensorStack( - ['zxwvupirsk', 'zxwvupirska string', 'zxwvupirska ..., - batch_size=torch.Size([3]), - device=None), - terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), - truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, - batch_size=torch.Size([3]), - device=None, - is_shared=False) - >>> # Roll out the environment with a specific initial state. 
- >>> init_state = env.reset(TensorDict(batch_size=[3])) - >>> print(env.rollout(3, auto_reset=False, tensordict=init_state)) - TensorDict( - fields={ - action: NonTensorStack( - [['a string', 'a string', 'a string'], ['a string'..., - batch_size=torch.Size([3, 3]), - device=None), - done: Tensor(shape=torch.Size([3, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False), - next: TensorDict( - fields={ - done: Tensor(shape=torch.Size([3, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False), - observation: NonTensorStack( - [[array(['nngcmflsana string', 'vrrbnhzpmga string..., - batch_size=torch.Size([3, 3]), - device=None), - terminated: Tensor(shape=torch.Size([3, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False), - truncated: Tensor(shape=torch.Size([3, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, - batch_size=torch.Size([3, 3]), - device=None, - is_shared=False), - observation: NonTensorStack( - [['nngcmflsan', array(['nngcmflsana string', 'vrrb..., - batch_size=torch.Size([3, 3]), - device=None), - terminated: Tensor(shape=torch.Size([3, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False), - truncated: Tensor(shape=torch.Size([3, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, - batch_size=torch.Size([3, 3]), - device=None, - is_shared=False) - - Example of dataloader yielding tensors: - >>> import random - >>> import string - >>> - >>> import tensordict as td - >>> import torch - >>> from tensordict import TensorDict - >>> from torchrl.data import Unbounded - >>> from torchrl.envs import DataLoadingPrimer, LLMEnv - >>> - >>> td.set_capture_non_tensor_stack(False).set() - >>> - >>> - >>> class DummyTensorDataLoader: - ... '''A dummy dataloader that generates tensors of random int64 values.''' - ... - ... def __init__(self, batch_size: int = 0, max_length: int = 10, padding: bool = False): - ... ''' - ... Args: - ... batch_size (int, optional): The batch size of the generated tensors. Defaults to 0. - ... max_length (int, optional): The maximum length of the generated tensors. Defaults to 10. - ... padding (bool, optional): Whether to pad the tensors to the maximum length. Defaults to False. - ... ''' - ... self.batch_size = batch_size - ... self.max_length = max_length - ... self.padding = padding - ... - ... def generate_random_tensor(self) -. torch.Tensor: - ... '''Generate a tensor of random int64 values.''' - ... length = random.randint(1, self.max_length) - ... return torch.tensor([random.randint(0, 100) for _ in range(length)], dtype=torch.int64) - ... - ... def pad_tensor(self, tensor: torch.Tensor) -. torch.Tensor: - ... '''Pad a tensor to the maximum length.''' - ... padding_length = self.max_length - len(tensor) - ... return torch.cat((torch.zeros(padding_length, dtype=torch.int64), tensor)) - ... - ... def __iter__(self): - ... return self - ... - ... def __next__(self): - ... if self.batch_size == 0: - ... tensor = self.generate_random_tensor() - ... return self.pad_tensor(tensor) if self.padding else tensor - ... else: - ... tensors = [self.generate_random_tensor() for _ in range(self.batch_size)] - ... if self.padding: - ... tensors = [self.pad_tensor(tensor) for tensor in tensors] - ... return torch.stack(tensors) - ... else: - ... return tensors - >>> - >>> # Create an LLM environment with non-string input/output and append a DataLoadingPrimer. 
- >>> env = LLMEnv(from_text=False) - >>> env = env.append_transform( - >>> DataLoadingPrimer( - >>> dataloader=DummyTensorDataLoader(), - >>> data_specs=[Unbounded(shape=(-1,), dtype=torch.int64)], - >>> ) - >>> ) - >>> print(env.rand_action(TensorDict())) - TensorDict( - fields={ - action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, is_shared=False)}, - batch_size=torch.Size([]), - device=None, - is_shared=False) - >>> print(env.rollout(3)) - LazyStackedTensorDict( - fields={ - action: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.int64, is_shared=False), - done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), - next: LazyStackedTensorDict( - fields={ - done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), - observation: Tensor(shape=torch.Size([3, -1]), device=cpu, dtype=torch.int64, is_shared=False), - terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), - truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, - exclusive_fields={ - }, - batch_size=torch.Size([3]), - device=None, - is_shared=False, - stack_dim=0), - observation: Tensor(shape=torch.Size([3, -1]), device=cpu, dtype=torch.int64, is_shared=False), - terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False), - truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, - exclusive_fields={ - }, - batch_size=torch.Size([3]), - device=None, - is_shared=False, - stack_dim=0) - >>> # Create an LLM environment with padded tensor input/output and append a DataLoadingPrimer. - >>> env = LLMEnv(from_text=False) - >>> env = env.append_transform( - >>> DataLoadingPrimer( - >>> dataloader=DummyTensorDataLoader(padding=True), - >>> data_specs=[Unbounded(shape=(-1,), dtype=torch.int64)], - >>> stack_method="as_padded_tensor", - >>> ) - >>> ) - >>> print(env.rollout(3, auto_reset=False, tensordict=env.reset(TensorDict(batch_size=[3])))) - LazyStackedTensorDict( - fields={ - action: Tensor(shape=torch.Size([3, 3, 1]), device=cpu, dtype=torch.int64, is_shared=False), - done: Tensor(shape=torch.Size([3, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False), - next: LazyStackedTensorDict( - fields={ - done: Tensor(shape=torch.Size([3, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False), - observation: Tensor(shape=torch.Size([3, 3, -1]), device=cpu, dtype=torch.int64, is_shared=False), - terminated: Tensor(shape=torch.Size([3, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False), - truncated: Tensor(shape=torch.Size([3, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, - exclusive_fields={ - }, - batch_size=torch.Size([3, 3]), - device=None, - is_shared=False, - stack_dim=1), - observation: Tensor(shape=torch.Size([3, 3, -1]), device=cpu, dtype=torch.int64, is_shared=False), - terminated: Tensor(shape=torch.Size([3, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False), - truncated: Tensor(shape=torch.Size([3, 3, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, - exclusive_fields={ - }, - batch_size=torch.Size([3, 3]), - device=None, - is_shared=False, - stack_dim=1) - - """ - - def __init__( - self, - dataloader: Iterable[dict[str, Any]], - *, - primers: Composite | None = None, - stack_method: Callable[[Any], Any] - | Literal["as_nested_tensor", "as_padded_tensor"] = None, - batch_size: int | torch.Size | None = None, - repeats: int | None = None, - device: torch.device | None = None, - 
group_repeats: bool = False, - ): - self.dataloader = dataloader - if repeats is None: - repeats = 0 - self.repeats = repeats - - # Determine batch-size - # We must distinguish the batch-size of the DL and the batch size of the transform. - # We may want more or less elements than the DL and the logic is slightly different so we - # allow to recompose batches on the fly. If the DL has a batch-size, every element will be - # unbound and stored in a queue. Otherwise, we get as many elements from the DL to fulfill - # the required batch-size. - # - # If the batch-size is passed, we will stack as many elements as necessary to fulfill this. - # If not, we try to get it from the dataloader. Contrary to the dataloader, we will always - # deliver the same batch-size (we create an infinite dataloader and reset when it's done), - # whereas DLs with drop_last=False may return batches of different sizes. - # - # If the batch size passed to the transform is empty (torch.Size(())) or 0, we will consider that - # the batch-size is determined on-the-fly. - # - # A batch-size of 0 in the dataloader means no batch-size. - # - # If needed, the various repeats can be grouped in a single batch through group_repeats. - # - # If auto_batch_size is on, we call auto_batch_size=True when doing TensorDict.from_dict: - # That way we get a tensordict of the right batch-size. - # If the dataloader has no batch-size, we're not sure that we can determine the batch-size - # automatically so we will consider that each element in the DL has a batch-size of 0 (ie, - # a single non-batched element is returned at a time). - - if batch_size is None: - batch_size = getattr(dataloader, "batch_size", torch.Size([])) - if batch_size == 0: - batch_size = torch.Size(()) - if not isinstance(batch_size, (list, tuple)): - batch_size = (batch_size,) - batch_size = torch.Size(batch_size) - auto_batch_size = getattr(dataloader, "batch_size", 1) != 0 - - if len(batch_size) > 1: - raise ValueError( - f"batch_size can only be 0 or 1D, got batch_size={batch_size}." 
- ) - - # We deliver all the repeats in the same batch - if repeats and group_repeats: - if batch_size == torch.Size([]): - batch_size = torch.Size((repeats,)) - else: - batch_size = torch.Size([batch_size[0] * repeats]) - - self._queue = deque() - self.auto_batch_size = auto_batch_size - self.batch_size = batch_size - self.endless_dataloader = self._endless_iter(self.dataloader) - - if stack_method is None: - stack_method = lazy_stack - elif stack_method == "as_nested_tensor": - stack_method = as_nested_tensor - elif stack_method == "as_padded_tensor": - stack_method = as_padded_tensor - elif not callable(stack_method): - raise ValueError(f"Unknown stack_method={stack_method}") - self.stack_method = stack_method - - if primers is None: - # We can get the primer from the dataloader itself - data = self._load_from_dataloader() - primers = make_composite_from_td(data, dynamic_shape=True) - if batch_size: - primers = primers.expand(batch_size) - self._queue.insert(0, data) - self.data_keys = list(primers.keys(True, True)) - else: - self.data_keys = list(primers.keys(True, True)) - - super().__init__( - primers=primers, - default_value=self._load_from_dataloader, - reset_key=None, - expand_specs=None, - single_default_value=True, - call_before_env_reset=True, - device=device, - ) - self._reset_key = "_reset" - - @classmethod - def _endless_iter(self, obj): - while True: - yield from obj - - def _load_from_dataloader(self, reset: torch.Tensor | None = None): - """Loads a single element from the dataloader, or alternatively from the buffer. - - If `reset` is passed, then one element per reset will be loaded. - """ - if reset is not None: - if not reset.any(): - raise RuntimeError("reset must have at least one True value.") - if reset.ndim > 0: - loaded = [self._load_from_dataloader() for _ in range(reset.sum())] - return self.stack_method(loaded) - - primers = getattr(self, "primers", None) - if primers is not None: - device = self.primers.device - else: - device = None - - if len(self._queue) > 0: - result = self._queue.popleft() - if result.device != device: - result = result.to(device) - return result - - data = next(self.endless_dataloader) - # Some heuristic here: - # if data is a map, assume its keys match the keys in spec - # TODO: one could rename the keys too - if isinstance(data, Mapping): - out = TensorDict.from_dict( - data, - auto_batch_size=self.auto_batch_size, - batch_dims=int(bool(self.auto_batch_size or self.batch_size)), - device=device, - ) - else: - raise TypeError( - "Data loader must return a mapping that can be automatically cast to a tensordict. Check that you have " - "the appropriate collate_fn in your dataloader to do so." - ) - if not out.ndim: - out = out.unsqueeze(0) - self._queue.extend( - [d for d in out.unbind(0) for _ in range(max(1, self.repeats))] - ) - out = self._queue.popleft() - return out - - def set_container(self, container: Transform | EnvBase) -> None: - result = super().set_container(container) - # Check batch size - parent = getattr(self, "parent", None) - if ( - self.batch_size is not None - and parent is not None - and parent.batch_size != self.batch_size - ): - warnings.warn( - f"The parent env has a different batch size than the {type(self).__name__} transform." 
- ) - return result - - def __repr__(self) -> str: - class_name = self.__class__.__name__ - return f"{class_name}(primers={self.primers}, dataloader={self.dataloader})" class KLRewardTransform(Transform): diff --git a/torchrl/modules/__init__.py b/torchrl/modules/__init__.py index 925dd8d6bf0..4b9b0933bac 100644 --- a/torchrl/modules/__init__.py +++ b/torchrl/modules/__init__.py @@ -93,7 +93,6 @@ ) from .utils import get_primers_from_module from .planners import CEMPlanner, MPCPlannerBase, MPPIPlanner # usort:skip -from .llm import CategoricalSequential, TransformersWrapper, vLLMWrapper __all__ = [ "Actor", @@ -109,7 +108,6 @@ "Conv3dNet", "ConvNet", "DTActor", - "CategoricalSequential", "DdpgCnnActor", "DdpgCnnQNet", "DdpgMlpActor", @@ -178,8 +176,6 @@ "VmapModule", "WorldModelWrapper", "distributions_maps", - "TransformersWrapper", - "vLLMWrapper", "get_primers_from_module", "recurrent_mode", "reset_noise", diff --git a/torchrl/modules/llm/__init__.py b/torchrl/modules/llm/__init__.py deleted file mode 100644 index 78b160c50be..00000000000 --- a/torchrl/modules/llm/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -from .common import CategoricalSequential -from .transformers_wrapper import TransformersWrapper - -from .vllm_wrapper import vLLMWrapper - -__all__ = ["TransformersWrapper", "vLLMWrapper", "CategoricalSequential"] diff --git a/torchrl/modules/llm/common.py b/torchrl/modules/llm/common.py deleted file mode 100644 index 8beab73d43b..00000000000 --- a/torchrl/modules/llm/common.py +++ /dev/null @@ -1,92 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. -from __future__ import annotations - -import torch -from tensordict import NestedKey, TensorDictBase -from tensordict.nn import TensorDictModuleBase, TensorDictSequential -from torch import distributions as D -from torch.distributions import Categorical -from torchrl.modules import MaskedCategorical - - -class CategoricalSequential(TensorDictModuleBase): - """A ProbabilisticTensorDictSequential subclass meant to work with LLMs. - - .. seealso:: :class:`~tensordict.nn.ProbabilisticTensorDictSequential` class. 
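For context, a small self-contained illustration of the `MaskedCategorical` distribution that the removed `get_dist` method builds over padded logits. The vocabulary size and the mask over the last two ids below are hypothetical; the real code derives the mask from the padding value instead.

import torch
from torchrl.modules import MaskedCategorical

logits = torch.randn(2, 5, 32)                     # (batch, sequence, vocab)
valid = torch.ones_like(logits, dtype=torch.bool)
valid[..., 30:] = False                            # pretend the last two vocab ids are padding/invalid
dist = MaskedCategorical(logits=logits, mask=valid)
tokens = dist.sample()                             # shape (2, 5); ids 30 and 31 are never drawn
log_prob = dist.log_prob(tokens)                   # shape (2, 5)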
- - """ - - generate: bool - - def get_dist( - self, - tensordict: TensorDictBase, - tensordict_out: TensorDictBase | None = None, - as_padded_tensor: bool | None = None, - as_nested_tensor: bool | None = None, - padding_value: float | None = None, - padding_side: str = "right", - layout: torch.layout | None = None, - **kwargs, - ) -> D.Distribution: - td_out = self(tensordict.copy()) - # By default, pad and use masked categorical - if as_padded_tensor is None: - as_padded_tensor = as_nested_tensor is not True - if padding_value is None: - padding_value = 0.0 - if as_nested_tensor is None: - as_nested_tensor = False - logits = td_out.get( - "logits", - as_padded_tensor=as_padded_tensor, - as_nested_tensor=as_nested_tensor, - padding_value=padding_value, - padding_side=padding_side, - layout=layout, - ) - if as_padded_tensor: - # We can use MaskedCategorical - dist = MaskedCategorical( - logits=logits, - mask=logits != padding_value, - # use_cross_entropy=True, - ) - return dist - return Categorical(logits) - - # Sampling is taken care of by the sub-modules - forward = TensorDictSequential.forward - - @property - def log_prob_keys(self) -> list[NestedKey]: - return getattr(self, "_log_prob_keys", ["log_probs"]) - - @log_prob_keys.setter - def log_prob_keys(self, value: list[NestedKey]): - self._log_prob_keys = value - - @property - def log_prob_key(self) -> NestedKey: - return self.log_prob_keys[0] - - @log_prob_key.setter - def log_prob_key(self, value: NestedKey) -> None: - self.log_prob_keys[0] = value - - @property - def dist_params_keys(self) -> list[NestedKey]: - raise NotImplementedError - - @property - def dist_sample_keys(self) -> list[NestedKey]: - return ["tokens_response"] - - def log_prob(self, data: TensorDictBase, **get_kwargs) -> TensorDictBase: - if not self.generate: - data = self(data) - return data.get(self.log_prob_key, **get_kwargs) - raise RuntimeError("log_prob not callable when generate=True.") diff --git a/torchrl/modules/llm/transformers_wrapper.py b/torchrl/modules/llm/transformers_wrapper.py deleted file mode 100644 index 13e41f47901..00000000000 --- a/torchrl/modules/llm/transformers_wrapper.py +++ /dev/null @@ -1,475 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. -from __future__ import annotations - -from typing import Literal - -import torch -from tensordict import ( - lazy_stack, - LazyStackedTensorDict, - NestedKey, - set_list_to_stack, - TensorDict, - TensorDictBase, -) -from tensordict.utils import _zip_strict -from torch.nn.utils.rnn import pad_sequence - -from torchrl.modules.llm.common import CategoricalSequential -from torchrl.modules.utils.utils import _unpad_tensors - - -class TransformersWrapper(CategoricalSequential): - """A wrapper class for Hugging Face Transformers models, providing a consistent interface for text generation and log probability computation. - - This class handles both text and token inputs, enabling text generation and log probability computation based on - the specified configuration. Unlike vLLM, Transformers require padded tensors for input and output sequences. - - Args: - model (transformers.LLM): The Hugging Face Transformers model to wrap. - - Keyword Args: - return_log_probs (bool | None, optional): Whether to return log probabilities of the generated tokens. - Defaults to `None`. 
- tokenizer (transformers.tokenization_utils.PreTrainedTokenizer | None, optional): The tokenizer to use for - encoding and decoding text. If `None`, the tokenizer associated with the model will be used. Defaults to - `None`. - from_text (bool, optional): Indicates whether the input is in text format. If `True`, the input is expected to - be text that will be tokenized. If `False`, the input is expected to be token sequences. Defaults to `True`. - device (torch.device | None, optional): The device to use for computation. If `None`, the default device will - be used. Defaults to `None`. - generate (bool, optional): Whether to enable text generation. If `True`, the model will generate text based on - the input. If `False`, only log probabilities will be computed for the response tokens/text. Defaults to `True`. - generate_kwargs (dict | None, optional): Additional arguments to pass to the model's generate method. These - arguments can control aspects of the generation process, such as temperature and top-k sampling. Defaults - to `None`. - tokenizer_kwargs (dict | None, optional): Additional arguments to pass to the tokenizer. These arguments can - control aspects of the tokenization process, such as padding and truncation. Defaults to `None`. - pad_output (bool, optional): Whether to pad the output sequences to a uniform length. Transformers require - `pad_output=True`, and the output sequences will be padded and represented as tensors. Defaults to `True`. - inplace (Literal[True, False, "empty"] | None, optional): Determines how the module should handle in-place - operations. If `True`, operations will be performed in-place. If `False`, a new TensorDict instance will be - created. If `"empty"`, the output data structure will be initialized with `input.empty()` (i.e., it will - conserve type, batch-size, and device). Defaults to `True`. - - .. note:: The tokenizer is used when `from_text` is `True` to convert input text into token sequences. It is also - required (or retrieved) when `pad_output` is `True` or when using text inputs with `generate=False` to ensure proper - tokenization and padding. - - Input Keys: - - - If `from_text` is `True`: - - - `"text"`: The input text to be tokenized. - - `"text_response"`: the response text (if `generate=False` as the log probabilities are computed for the response.) - - - If `from_text` is `False`: - - - "tokens": The input token sequences. - - "attention_mask": The attention mask for the tokens. - - "tokens_response": The response token sequences (if `generate=False` as the log probabilities are - computed for the response.) - - Output Keys: - - - `"tokens_response"`: The generated token sequences. - - `"log_probs"`: The log probabilities of the generated tokens (if `return_log_probs` is `True`). - - `"text_response"`: The generated text (if `from_text` is `True` and `generate` is `True`). - - Example: - >>> from transformers import AutoModelForCausalLM, AutoTokenizer - >>> model = AutoModelForCausalLM.from_pretrained("gpt2") - >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") - >>> wrapper = TransformersWrapper( - ... model, - ... tokenizer=tokenizer, - ... from_text=True, - ... generate=True - ... ) - >>> input_data = TensorDict({"text": ["Hello, world!", "This is another text"]}, batch_size=1) - >>> output_data = wrapper(input_data) - >>> print(output_data["text_response"]) - - .. seealso:: :func:`~torchrl.modules.vLLMWrapper` for a similar interface using vLLM. 
- - """ - - text_key: NestedKey = ("text",) - token_key: NestedKey = ("tokens",) - token_response_key: NestedKey = ("tokens_response",) - text_response_key: NestedKey = ("text_response",) - attention_mask_key: NestedKey = ("attention_mask",) - - def __init__( - self, - model: transformers.LLM, # noqa - # noqa - *, - return_log_probs: bool | None = None, - tokenizer: transformers.tokenization_utils.PreTrainedTokenizer # noqa - | None = None, - # noqa - from_text: bool = True, - device: torch.device | None = None, - generate: bool = True, - generate_kwargs: dict | None = None, - tokenizer_kwargs: dict | None = None, - pad_output: bool = True, - inplace: Literal[True, False, "empty"] | None = True, - ): - super().__init__() - - self.model = model - self.from_text = from_text - self._device = device - self.generate = generate - self.inplace = inplace - self.pad_output = pad_output - padding_value = None - - if not tokenizer_kwargs: - tokenizer_kwargs = {} - if not tokenizer_kwargs.setdefault("return_attention_mask", True): - raise RuntimeError - - # If we don't pad, we use lists - if not self.pad_output: - raise NotImplementedError("transformers requires `pad_output=True`.") - if tokenizer_kwargs.setdefault("return_tensors", "pt") != "pt": - raise RuntimeError - if tokenizer_kwargs.setdefault("padding", self.pad_output) not in ( - self.pad_output, - ): - raise RuntimeError - if tokenizer_kwargs.setdefault("padding_side", "left") != "left": - raise RuntimeError - - self.tokenizer_kwargs = tokenizer_kwargs - if (pad_output or (from_text and not generate)) and tokenizer is None: - # We need a tokenizer if we pad or when using text inputs with generate=False - # The latter case is due to the fact that we want the log-probs for the response only, - # but if the response is presented as a text we have to tokenize the whole prompt + response and - # identify where the prompt ends and where the response starts. - tokenizer = model.get_tokenizer() - self.tokenizer = tokenizer - if tokenizer is not None and ( - not hasattr(self.tokenizer, "pad_token") or self.tokenizer.pad_token is None - ): - self.tokenizer.pad_token = self.tokenizer.eos_token - if self.tokenizer is not None: - padding_value = self.tokenizer(self.tokenizer.pad_token)["input_ids"][0] - self.padding_value = padding_value - - if generate_kwargs is None: - generate_kwargs = {} - else: - generate_kwargs = dict(generate_kwargs) - - if not generate: - # TODO - if return_log_probs in (None, True): - return_log_probs = True - else: - raise ValueError( - "return_log_probs must be True or None when generate=False." 
- ) - elif return_log_probs in (None, False): - return_log_probs = False - self.return_log_probs = return_log_probs - - generate_kwargs.setdefault("tokenizer", self.tokenizer) - generate_kwargs.setdefault("output_logits", self.return_log_probs) - generate_kwargs.setdefault("return_dict_in_generate", True) - - self.generate_kwargs = generate_kwargs - - if from_text: - self.in_keys = [self.text_key] - else: - self.in_keys = [self.token_key, self.attention_mask_key] - self.out_keys = [self.token_response_key] - if from_text: - self.out_keys += [self.text_response_key, self.token_key] - if self.return_log_probs: - self.out_keys += [self.log_prob_key, "logits"] - - @set_list_to_stack(True) - def forward( - self, - tensordict: TensorDictBase, - tensordict_out: TensorDictBase | None = None, - **kwargs, - ) -> TensorDictBase: - if not tensordict.ndim: - # unsqueeze - squeeze the input - try: - return self(lazy_stack([tensordict]))[0] - except Exception as e: - raise RuntimeError( - f"Unsqueeze/squeeze failed. Inputs to {type(self).__name__} should ideally be 1 dimensional." - ) from e - _source_device = None - if self._device: - _source_device = tensordict.device - if tensordict.device: - tensordict = tensordict.copy().clear_device_() - - out = LazyStackedTensorDict( - *[ - TensorDict( - device=tensordict.device, batch_size=tensordict.batch_size[1:] - ) - for _ in range(tensordict.shape[0]) - ] - ) - if self.from_text: - if self.generate: - out = self._from_transformers_generate_text(tensordict, out=out) - else: - out = self._from_transformers_logprobs_text(tensordict, out=out) - else: - if self.generate: - out = self._from_transformers_generate_tokens(tensordict, out=out) - else: - out = self._from_transformers_logprobs_tokens(tensordict, out=out) - if _source_device: - out = out.to(_source_device) - - if tensordict_out is None: - if self.inplace is True: - tensordict_out = tensordict - elif self.inplace is False: - tensordict_out = TensorDict() - elif self.inplace == "empty": - tensordict_out = tensordict.empty() - - if tensordict_out is not None: - result = tensordict_out - result.update(out, keys_to_update=self.out_keys) - else: - result = out - keys = list(set(self.out_keys + list(tensordict.keys(True, True)))) - return tensordict.update(result, keys_to_update=keys) - return result - - def _from_transformers_generate_text(self, td, out): - pad_val = self.tokenizer.pad_token_id - - text = td.get(self.text_key) - if not isinstance(text, (list, str)): - text = text.tolist() - tokens_in = self.tokenizer(text, **self.tokenizer_kwargs) - input_ids = tokens_in["input_ids"] - attention_mask = tokens_in["attention_mask"] - tokens_out = self.model.generate( - input_ids=input_ids, attention_mask=attention_mask, **self.generate_kwargs - ) - sequences = tokens_out["sequences"] - sequences = sequences[..., input_ids.shape[-1] :] - - mask_sequences = sequences != pad_val - sequences = _unpad_tensors(sequences, mask_sequences, as_nested=False) - if self.return_log_probs: - logits = torch.stack(list(tokens_out["logits"]), 1) - logits = _unpad_tensors(logits, mask_sequences, as_nested=False) - log_probs, logits = self._log_probs_generate( - sequences, logits, pad_val=pad_val - ) - response_text = self.tokenizer.batch_decode(sequences) - out.set(self.token_response_key, sequences) - out.set( - self.token_key, _unpad_tensors(input_ids, attention_mask, as_nested=False) - ) - out.set(self.text_response_key, list(response_text)) - out.set( - self.attention_mask_key, - _unpad_tensors(attention_mask, attention_mask, 
as_nested=False), - ) - if self.return_log_probs: - out.set(self.log_prob_key, log_probs) - out.set("logits", _unpad_tensors(logits, mask_sequences, as_nested=False)) - return out - - def _from_transformers_generate_tokens(self, td, out): - pad_val = self.tokenizer.pad_token_id - - input_ids = td.get( - self.token_key, - as_padded_tensor=True, - padding_side="left", - padding_value=pad_val, - ) - attention_mask = td.get( - self.attention_mask_key, - as_padded_tensor=True, - padding_side="left", - padding_value=0, - ) - if attention_mask is None: - attention_mask = (input_ids != pad_val).to(torch.int64) - tokens_out = self.model.generate( - input_ids=input_ids, attention_mask=attention_mask, **self.generate_kwargs - ) - sequences = tokens_out["sequences"] - sequences = sequences[:, input_ids.shape[-1] :] - mask_sequences = sequences != pad_val - sequences = _unpad_tensors(sequences, mask_sequences, as_nested=False) - - if self.return_log_probs: - logits = tokens_out["logits"] - logits = torch.stack(list(logits), 1) - logits = _unpad_tensors(logits, mask_sequences, as_nested=False) - log_probs, logits = self._log_probs_generate( - sequences, logits, pad_val=pad_val - ) - out.set( - self.token_response_key, - sequences, - ) - out.set( - self.token_key, _unpad_tensors(input_ids, attention_mask, as_nested=False) - ) - out.set( - self.attention_mask_key, - _unpad_tensors(attention_mask, attention_mask, as_nested=False), - ) - if self.return_log_probs: - out.set(self.log_prob_key, log_probs) - out.set("logits", _unpad_tensors(logits, mask_sequences, as_nested=False)) - return out - - def _from_transformers_logprobs_text(self, td, out): - pad_val = self.tokenizer.pad_token_id - - prompt_txt = td.get(self.text_key) - if not isinstance(prompt_txt, (list, str)): - prompt_txt = prompt_txt.tolist() - response_txt = td.get(self.text_response_key) - if not isinstance(response_txt, (list, str)): - response_txt = response_txt.tolist() - total_txt = [x + y for x, y in _zip_strict(prompt_txt, response_txt)] - total_tokens_in = self.tokenizer(total_txt, **self.tokenizer_kwargs) - prompt_tokens_in = self.tokenizer(prompt_txt, **self.tokenizer_kwargs) - - total_input_ids = total_tokens_in["input_ids"] - total_attention_mask = total_tokens_in["attention_mask"] - prompt_input_ids = prompt_tokens_in["input_ids"] - prompt_attention_mask = prompt_tokens_in["attention_mask"] - - total_tokens_out = self.model( - total_input_ids, attention_mask=total_attention_mask, **self.generate_kwargs - ) - - total_input_ids = _unpad_tensors( - total_input_ids, total_attention_mask, as_nested=False - ) - prompt_input_ids = _unpad_tensors( - prompt_input_ids, prompt_attention_mask, as_nested=False - ) - sequences = [ - _total_input_ids[_prompt_input_ids.shape[-1] :] - for _total_input_ids, _prompt_input_ids in zip( - total_input_ids, prompt_input_ids - ) - ] - # response_attention_mask = total_attention_mask[ - # :, prompt_attention_mask.shape[-1] : - # ] - log_probs, logits = self._log_probs_from_logits( - total_tokens_out, sequences, pad_val=pad_val - ) - - out.set("logits", logits) - out.set(self.log_prob_key, log_probs) - out.set(self.token_response_key, sequences) - return out - - def _from_transformers_logprobs_tokens(self, td, out): - pad_val = self.tokenizer.pad_token_id - - prompt_input_ids = td.get( - self.token_key, - as_list=True, - ) - response_input_ids = td.get( - self.token_response_key, - as_list=True, - ) - prompt_attention_mask = td.get( - self.attention_mask_key, - as_list=True, - ) - - total_input_ids = [ - 
torch.cat([_prompt_input_ids, _response_input_ids], -1) - for _prompt_input_ids, _response_input_ids in zip( - prompt_input_ids, response_input_ids - ) - ] - total_input_ids = pad_sequence( - total_input_ids, - padding_value=pad_val, - padding_side="left", - batch_first=True, - ) - total_attention_mask = (total_input_ids != pad_val).to(torch.int64) - - if prompt_attention_mask is None: - prompt_attention_mask = [ - (_prompt_input_ids != pad_val).to(torch.int64) - for _prompt_input_ids in prompt_input_ids - ] - - total_tokens_out = self.model( - total_input_ids, attention_mask=total_attention_mask, **self.generate_kwargs - ) - log_probs, logits = self._log_probs_from_logits( - total_tokens_out, response_input_ids, pad_val=pad_val - ) - - out.set("logits", logits) - out.set(self.log_prob_key, log_probs) - return out - - @classmethod - def _log_probs_from_logits(cls, total_tokens_out, response_input_ids, pad_val): - response_input_ids = pad_sequence( - response_input_ids, - padding_value=pad_val, - batch_first=True, - padding_side="left", - ) - pad_mask = response_input_ids != pad_val - - logits = total_tokens_out["logits"] - logits = logits.log_softmax(dim=-1) - logits = logits[:, -response_input_ids.shape[-1] - 1 : -1, :] - - log_probs = logits.gather(-1, response_input_ids.unsqueeze(-1)).squeeze(-1) - - # Recover the list - log_probs = _unpad_tensors(log_probs, pad_mask) - logits = _unpad_tensors(logits, pad_mask) - return log_probs, logits - - @classmethod - def _log_probs_generate(cls, sequences, logits, pad_val): - tokens = pad_sequence( - sequences, - padding_value=pad_val, - batch_first=True, - padding_side="left", - ) - logits = pad_sequence( - logits, - padding_value=0.0, - batch_first=True, - padding_side="left", - ) - - logits = logits.log_softmax(dim=-1) - log_probs = logits.gather(-1, tokens.unsqueeze(-1)).squeeze(-1) - return log_probs, logits diff --git a/torchrl/modules/llm/vllm_wrapper.py b/torchrl/modules/llm/vllm_wrapper.py deleted file mode 100644 index 58c5bf6609c..00000000000 --- a/torchrl/modules/llm/vllm_wrapper.py +++ /dev/null @@ -1,632 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. -from __future__ import annotations - -import collections -from typing import Literal - -import torch -from tensordict import ( - lazy_stack, - maybe_dense_stack, - NestedKey, - TensorDict, - TensorDictBase, -) -from tensordict.tensorclass import from_dataclass, NonTensorStack, TensorClass -from tensordict.utils import _zip_strict, expand_as_right - -from torchrl.envs.utils import _classproperty -from torchrl.modules.llm import CategoricalSequential - - -class vLLMWrapper(CategoricalSequential): - """A wrapper class for vLLM models, providing a consistent interface for text generation and log probability computation, similar to the Hugging Face Transformers interface. - - This class allows for handling both text and token inputs, enabling text generation and log probability - computation based on the specified configuration. - - .. note:: The default arguments of the `vLLMWrapper` class are set to make it easy to run this backend with - the :class:`~torchrl.envs.custom.llm.LLMEnv` class. - - Args: - model (vllm.LLM): The vLLM model to wrap. - - Keyword Args: - return_log_probs (bool | None, optional): Whether to return log probabilities of the generated tokens. - Defaults to `None`. 
-        tokenizer (transformers.tokenization_utils.PreTrainedTokenizer | None, optional): The tokenizer to use for
-            encoding and decoding text. If `None`, the tokenizer associated with the model will be used. Defaults to
-            `None`.
-        from_text (bool, optional): Indicates whether the input is in text format. If `True`, the input is expected to
-            be text that will be tokenized. If `False`, the input is expected to be token sequences. Defaults to `True`.
-        device (torch.device | None, optional): The device to use for computation. If `None`, the default device will
-            be used. Defaults to `None`.
-        generate (bool, optional): Whether to enable text generation. If `True`, the model will generate text based on
-            the input. If `False`, only log probabilities will be computed for the response tokens/text. Defaults to `True`.
-        generate_kwargs (dict | None, optional): Additional arguments to pass to the model's generate method. These
-            arguments can control aspects of the generation process, such as temperature and top-k sampling. Defaults
-            to `None`.
-        tokenizer_kwargs (dict | None, optional): Additional arguments to pass to the tokenizer. These arguments can
-            control aspects of the tokenization process, such as padding and truncation. Defaults to `None`.
-        pad_output (bool, optional): Whether to pad the output sequences to a uniform length. If `True`, the output
-            sequences will be padded and represented as tensors. If `False`, lists of tokens will be used without
-            padding. Defaults to `False`.
-
-            .. warning:: The default value of `pad_output` differs from :func:`~torchrl.modules.TransformersWrapper`
-                which does not handle non-padded inputs.
-
-        inplace (Literal[True, False, "empty"] | None, optional): Determines how the module should handle in-place
-            operations. If `True`, operations will be performed in-place. If `False`, a new TensorDict instance will be
-            created. If `"empty"`, the output data structure will be initialized with `input.empty()` (i.e., it will
-            conserve type, batch-size, and device). Defaults to `True` when generating a single sample, `False`
-            otherwise.
-
-    .. note:: The tokenizer is used when `from_text` is `True` to convert input text into token sequences. It is also
-        required (or retrieved) when `pad_output` is `True` or when using text inputs with `generate=False` to ensure proper
-        tokenization and padding.
-
-    Input Keys:
-
-        - If `from_text` is `True`:
-
-            - `"text"`: The input text to be tokenized.
-              `"text_response"`: the response text (if `generate=False` as the log probabilities are computed for the response.)
-
-        - If `from_text` is `False`:
-
-            - "tokens": The input token sequences.
-              "attention_mask": The attention mask for the tokens.
-              "tokens_response": The response token sequences (if `generate=False` as the log probabilities are
-              computed for the response.)
-
-    Output Keys:
-
-        - `"tokens_response"`: The generated token sequences.
-          `"log_probs"`: The log probabilities of the generated tokens (if `return_log_probs` is `True`).
-          `"text_response"`: The generated text (if `from_text` is `True` and `generate` is `True`).
-
-    Example:
-        >>> from vllm import LLM
-        >>> from transformers import AutoTokenizer
-        >>> model = LLM("gpt2")
-        >>> wrapper = vLLMWrapper(
-        ...     model,
-        ...     from_text=True,
-        ...     generate=True
-        ... )
-        >>> input_data = LLMData(text=NonTensorStack("Hello, world!", "This is another text"), batch_size=1)
-        >>> output_data = wrapper(input_data)
-        >>> print(output_data.text_response)
-
-    ..
seealso:: :func:`~torchrl.modules.TransformersWrapper` for a similar interface using the Hugging Face - Transformers library. - """ - - text_key: NestedKey = ("text",) - token_key: NestedKey = ("tokens",) - token_response_key: NestedKey = ("tokens_response",) - text_response_key: NestedKey = ("text_response",) - attention_mask_key: NestedKey = ("attention_mask",) - - def __init__( - self, - model: vllm.LLM, # noqa - # noqa - *, - return_log_probs: bool | None = None, - tokenizer: transformers.tokenization_utils.PreTrainedTokenizer # noqa - | None = None, - # noqa - from_text: bool = True, - device: torch.device | None = None, - generate: bool = True, - generate_kwargs: dict | None = None, - tokenizer_kwargs: dict | None = None, - pad_output: bool = False, - inplace: Literal[True, False, "empty"] | None = None, - ): - super().__init__() - - from vllm import SamplingParams - - self.model = model - self.from_text = from_text - self._device = device - self.generate = generate - self.pad_output = pad_output - padding_value = None - - if not tokenizer_kwargs: - tokenizer_kwargs = {} - if not tokenizer_kwargs.setdefault("return_attention_mask", True): - raise RuntimeError - - # If we don't pad, we use lists - return_tensors = "pt" if self.pad_output else False - if return_tensors: - if ( - tokenizer_kwargs.setdefault("return_tensors", return_tensors) - != return_tensors - ): - raise RuntimeError - if tokenizer_kwargs.setdefault("padding", self.pad_output) not in ( - self.pad_output, - ): - raise RuntimeError - if tokenizer_kwargs.setdefault("padding_side", "left") != "left": - raise RuntimeError - - self.tokenizer_kwargs = tokenizer_kwargs - if (pad_output or (from_text and not generate)) and tokenizer is None: - # We need a tokenizer if we pad or when using text inputs with generate=False - # The latter case is due to the fact that we want the log-probs for the response only, - # but if the response is presented as a text we have to tokenize the whole prompt + response and - # identify where the prompt ends and where the response starts. - tokenizer = model.get_tokenizer() - self.tokenizer = tokenizer - if tokenizer is not None and ( - not hasattr(self.tokenizer, "pad_token") or self.tokenizer.pad_token is None - ): - self.tokenizer.pad_token = self.tokenizer.eos_token - if self.tokenizer is not None: - padding_value = self.tokenizer(self.tokenizer.pad_token)["input_ids"][0] - self.padding_value = padding_value - - if generate_kwargs is None: - generate_kwargs = {} - else: - generate_kwargs = dict(generate_kwargs) - - if generate_kwargs.get("n", 1) > 1: - if inplace in (True, "empty"): - raise ValueError( - "inplace must be False (or None) when generating more than one sample." - ) - if inplace is None: - inplace = False - elif inplace is None: - inplace = True - - self.inplace = inplace - - prompt_logprobs = False - - if not generate: - # We want only the log-probs, we generate a single token (that we then discard) - # and retrieve the prompt log-probs - generate_kwargs["max_tokens"] = 1 - prompt_logprobs = True - if return_log_probs in (None, True): - return_log_probs = True - else: - raise ValueError( - "return_log_probs must be True or None when generate=False." 
- ) - elif return_log_probs in (None, False): - return_log_probs = False - self.return_log_probs = return_log_probs - - generate_kwargs.setdefault("detokenize", not pad_output) - generate_kwargs.setdefault("prompt_logprobs", prompt_logprobs) - generate_kwargs.setdefault("logprobs", return_log_probs) - sampling_params = SamplingParams(**generate_kwargs) - self.sampling_params = sampling_params - - if from_text: - self.in_keys = [self.text_key] - else: - self.in_keys = [self.token_key, self.attention_mask_key] - self.out_keys = [self.token_response_key] - if from_text: - self.out_keys += [self.text_response_key, self.token_key] - if self.return_log_probs: - self.out_keys += [self.log_prob_key] - - def forward( - self, - tensordict: TensorDictBase, - tensordict_out: TensorDictBase | None = None, - **kwargs, - ) -> TensorDictBase: - if not tensordict.ndim: - # unsqueeze - squeeze the input - try: - return self(lazy_stack([tensordict]))[0] - except Exception as e: - raise RuntimeError( - f"Unsqueeze/squeeze failed. Inputs to {type(self).__name__} should ideally be 1 dimensional." - ) from e - - _source_device = None - if self._device: - _source_device = tensordict.device - if tensordict.device: - tensordict = tensordict.copy().clear_device_() - - if self.from_text: - if self.generate: - out = self._from_vllm_generate_text(tensordict) - else: - out = self._from_vllm_logprobs_text(tensordict) - else: - if self.generate: - out = self._from_vllm_generate_tokens(tensordict) - else: - out = self._from_vllm_logprobs_tokens(tensordict) - if _source_device: - out = out.to(_source_device) - - if tensordict_out is None: - if self.inplace is True: - tensordict_out = tensordict - elif self.inplace is False: - tensordict_out = out - elif self.inplace == "empty": - tensordict_out = tensordict.empty() - - if tensordict_out is not None and tensordict_out is not out: - result = tensordict_out - result.update(out, keys_to_update=self.out_keys) - elif tensordict_out is not out: - result = out - keys = list(set(self.out_keys + list(tensordict.keys(True, True)))) - return tensordict.update(result, keys_to_update=keys) - else: - result = out - return result - - def _from_vllm_generate_text(self, td): - kwargs = {"sampling_params": self.sampling_params} - args = () - input_ids = None - attention_mask = None - if self.pad_output: - tokenizer_kwargs = self.tokenizer_kwargs - text = td.get(self.text_key) - if not isinstance(text, (list, str)): - text = text.tolist() - tokens_in = TensorDict.from_dict(self.tokenizer(text, **tokenizer_kwargs)) - # out.set("tokens_in", tokens_in) - input_ids, attention_mask = ( - tokens_in["input_ids"], - tokens_in["attention_mask"], - ) - prompt_token_ids = self._to_list(input_ids, attention_mask) - kwargs["prompt_token_ids"] = prompt_token_ids - else: - txt = td.get(self.text_key) - if not isinstance(txt, (list, str)): - txt = txt.tolist() - args = (txt,) - - tokens_out = self.model.generate(*args, **kwargs) - tokens_out = self._get_output_tokens_and_log_probs(tokens_out) - if self.pad_output: - tokens_out.set( - self.text_response_key, - NonTensorStack( - *self.tokenizer.batch_decode(tokens_out[self.token_response_key]) - ), - ) - in_keys = [ - self.log_prob_key, - self.token_response_key, - self.text_response_key, - self.token_key, - self.attention_mask_key, - ] - out = tokens_out.select(*in_keys, strict=False) - # We might already have the tokens - if input_ids is not None and self.token_key not in out: - out[self.token_key] = input_ids - if attention_mask is not None and 
self.attention_mask_key not in out: - out[self.attention_mask_key] = attention_mask - inputs = td.select(*self.in_keys, strict=False) - if inputs.ndim < out.ndim: - # This happens when n > 1 - inputs = inputs.unsqueeze(-1).expand(out.shape) - out.update(inputs) - return out - - def _from_vllm_logprobs_text(self, td): - text_prompt = td.get(self.text_key) - if not isinstance(text_prompt, list): - text_prompt = text_prompt.tolist() - text_response = td.get(self.text_response_key) - if not isinstance(text_response, list): - text_response = text_response.tolist() - text = [_x + _y for _x, _y in _zip_strict(text_prompt, text_response)] - - tokenized_total = self.tokenizer(text, **self.tokenizer_kwargs) - tokenized_prompt_only = self.tokenizer(text_prompt, **self.tokenizer_kwargs) - - input_ids_total = tokenized_total["input_ids"] - attention_mask_total = tokenized_total["attention_mask"] - - if not self.pad_output: - input_ids_prompt = tokenized_prompt_only["input_ids"] - attention_mask_prompt = tokenized_prompt_only["attention_mask"] - input_ids_response = [] - for token_total, token_prompt in zip(input_ids_total, input_ids_prompt): - input_ids_response.append(token_total[len(token_prompt) :]) - attention_mask_response = [] - for mask, mask_prompt in zip(attention_mask_total, attention_mask_prompt): - attention_mask_response.append(mask[len(mask_prompt) :]) - else: - input_ids_prompt: torch.Tensor = tokenized_prompt_only["input_ids"] - attention_mask_prompt: torch.Tensor = tokenized_prompt_only[ - "attention_mask" - ] - input_ids_response: torch.Tensor = input_ids_total[ - :, input_ids_prompt.shape[1] : - ] - # response_attention_mask: torch.Tensor = attention_mask_total[ - # :, attention_mask_prompt.shape[1] : - # ] - - input_ids_total = self._to_list(input_ids_total, attention_mask_total) - kwargs = {"sampling_params": self.sampling_params} - if self.tokenizer is not None: - kwargs.update({"prompt_token_ids": input_ids_total}) - args = () - else: - # TODO: this is unreachable as of now - but ultimately we may want to pass the text directly - args = (td[self.text_key],) - tokens_out = self.model.generate(*args, **kwargs) - tokens_out = _RequestOutput_tc.from_request_output(tokens_out) - tokens_out = tokens_out.select( - "prompt_token_ids", "prompt_logprobs", strict=False - )._tensordict - - # we disregard the tokens from the prompt to focus on those of the response - if self.pad_output: - lps = tokens_out.get( - "prompt_logprobs", as_padded_tensor=True, padding_side="left" - ) - lps = lps[..., -input_ids_response.shape[1] :] - padded = input_ids_response == self.padding_value - lps = torch.where(~padded, lps, 0.0) - else: - lps = tokens_out.get( - "prompt_logprobs", - as_list=True, - ) - # We use a nested tensor as it will be unbound during writing - lps = torch.nested.nested_tensor( - [lp[..., -len(tr) :] for lp, tr in zip(lps, input_ids_response)] - ) - - out = tokens_out.empty(recurse=True) - if isinstance(input_ids_response, list): - input_ids_response = torch.nested.nested_tensor(input_ids_response) - out["tokens_response"] = input_ids_response - out[self.log_prob_key] = lps - inputs = td.select(*self.in_keys, strict=False) - if inputs.ndim < out.ndim: - # This happens when n > 1 - inputs = inputs.unsqueeze(-1).expand(out.shape) - out.update(inputs) - return out - - def _from_vllm_generate_tokens(self, td): - input_ids = td.get(self.token_key) - attention_mask = td.get(self.attention_mask_key) - input_ids_list = self._to_list(input_ids, attention_mask) - args = () - kwargs = { - 
"sampling_params": self.sampling_params, - "prompt_token_ids": input_ids_list, - } - tokens_out = self.model.generate(*args, **kwargs) - tokens_out = _RequestOutput_tc.from_request_output(tokens_out) - # When not generate, we don't want to overwrite this - tokens_response_td = tokens_out.outputs._tensordict.select( - "token_ids", "logprobs", strict=False - ) - if self.pad_output: - tokens_response_td = tokens_response_td.densify( - layout=torch.strided - ).to_padded_tensor(padding=self.padding_value) - tokens_response_td.rename_key_("token_ids", "tokens_response") - if self.return_log_probs: - tokens_response_td.rename_key_("logprobs", self.log_prob_key) - if self.pad_output: - padded_values = ( - tokens_response_td["tokens_response"] == self.padding_value - ) - if padded_values.any(): - lps = tokens_response_td[self.log_prob_key] - lps = torch.where(expand_as_right(~padded_values, lps), lps, 0.0) - tokens_response_td[self.log_prob_key] = lps - out = tokens_response_td.empty(recurse=True) - out.update( - tokens_response_td, - keys_to_update=(self.token_response_key, self.log_prob_key), - ) - inputs = td.select(*self.in_keys, strict=False) - if inputs.ndim < out.ndim: - # This happens when n > 1 - inputs = inputs.unsqueeze(-1).expand(out.shape) - out.update(inputs) - return out - - def _from_vllm_logprobs_tokens(self, td): - - tokens = td.get(self.token_key) - tokens_response = td.get(self.token_response_key) - attention_mask = td.get(self.attention_mask_key) - - tokens = torch.cat([tokens, tokens_response], -1) - if attention_mask is not None: - attention_mask = torch.cat( - [attention_mask, attention_mask.new_ones(tokens_response.shape)], -1 - ) - input_ids_list = self._to_list(tokens, attention_mask) - args = () - kwargs = { - "sampling_params": self.sampling_params, - "prompt_token_ids": input_ids_list, - } - tokens_out = self.model.generate(*args, **kwargs) - tokens_out = _RequestOutput_tc.from_request_output(tokens_out) - prompt_logprobs = tokens_out.prompt_logprobs - prompt_logprobs = prompt_logprobs[..., -tokens_response.shape[-1] :] - padded = tokens_response == self.padding_value - prompt_logprobs = torch.where(~padded, prompt_logprobs, 0.0) - out = tokens_out._tensordict.empty(recurse=True) - out.set(self.log_prob_key, prompt_logprobs) - out.set(self.token_response_key, tokens_response) - inputs = td.select(*self.in_keys, strict=False) - if inputs.ndim < out.ndim: - # This happens when n > 1 - inputs = inputs.unsqueeze(-1).expand(out.shape) - out.update(inputs) - return out - - def _get_output_tokens_and_log_probs(self, tokens_out): - padding_value = self.padding_value - tokens_out = _RequestOutput_tc.from_request_output(tokens_out) - - # When not generate, we don't want to overwrite this - tokens_response_td = tokens_out.outputs._tensordict.select( - "text", "token_ids", "logprobs", strict=False - ) - if self.pad_output: - tokens_response_td = tokens_response_td.densify( - layout=torch.strided - ).to_padded_tensor(padding=padding_value) - tokens_response_td.rename_key_("token_ids", "tokens_response") - tokens_response_td.rename_key_("text", "text_response") - if not self.pad_output: - # Then we can safely move the input tokens, but otherwise they - # may need padding - tokens_out = tokens_out.select("prompt_token_ids") - if tokens_out.ndim < tokens_response_td.ndim: - tokens_out = tokens_out.unsqueeze(1).expand(tokens_response_td.shape) - tokens_response_td.update(tokens_out).rename_key_( - "prompt_token_ids", self.token_key - ) - - if self.return_log_probs or "logprobs" in 
tokens_response_td: - tokens_response_td.rename_key_("logprobs", self.log_prob_key) - if self.pad_output: - padded_values = tokens_response_td["tokens_response"] == padding_value - if padded_values.any(): - lps = tokens_response_td[self.log_prob_key] - lps = torch.where(expand_as_right(~padded_values, lps), lps, 0.0) - tokens_response_td[self.log_prob_key] = lps - return tokens_response_td - - def _to_list(self, tokens, attention_mask): - """Converts a tensor of integer in a masked list (of lists) of integers.""" - if isinstance(tokens, torch.Tensor): - # TODO: make this an ND NonTensorStack - parent = [] - queue = collections.deque() - if attention_mask is None: - attention_mask = torch.ones_like(tokens) - queue.append((tokens, attention_mask.bool(), parent)) - while queue: - token, amask, _parent = queue.popleft() - if token.ndim == 1: - _parent.extend(token[amask].tolist()) - else: - _parent.extend([[] for _ in range(token.shape[0])]) - queue.extend( - [ - (t, m, local_parent) - for t, m, local_parent in zip(token, amask, _parent) - ] - ) - tokens = parent - return tokens - - @_classproperty - def CompletionOutput_tc(cls): - import vllm - - if hasattr(cls, "_CompletionOutput_tc"): - return cls._CompletionOutput_tc - CompletionOutput_tc = from_dataclass(vllm.outputs.CompletionOutput) - cls._CompletionOutput_tc = CompletionOutput_tc - return CompletionOutput_tc - - -class _RequestOutput_tc(TensorClass["nocast"]): - request_id: str - prompt: str - prompt_token_ids: str - prompt_logprobs: str - outputs: str - finished: str - metrics: str - lora_request: str - encoder_prompt: str - encoder_prompt_token_ids: str - num_cached_tokens: str - - def __post_init__(self): - CompletionOutput_tc = vLLMWrapper.CompletionOutput_tc - - def postproc(output): - def get_logprob(output): - t = [] - for v, tid in zip(output.logprobs, output.token_ids): - t.append( - v[int(tid)]["logprob"] - if v[tid].get("logprob") is not None - else 0.0 - ) - return torch.tensor(t) - - if output.logprobs: - output.logprobs = get_logprob(output) - output.token_ids = torch.as_tensor(output.token_ids) - return output - - if isinstance(self.outputs, list): - outputs = self.outputs - outputs = [ - postproc(from_dataclass(output, dest_cls=CompletionOutput_tc)) - for output in outputs - ] - if len(outputs) == 1: - self.outputs = outputs[0] - else: - self.outputs = maybe_dense_stack(outputs) - if self.prompt_logprobs is not None: - self.prompt_logprobs = torch.tensor( - [ - v[int(tid)].logprob if v is not None else 0.0 - for v, tid in _zip_strict( - self.prompt_logprobs, self.prompt_token_ids - ) - ] - ) - self.prompt_token_ids = torch.as_tensor(self.prompt_token_ids) - self.num_cached_tokens = torch.as_tensor(self.num_cached_tokens) - - @classmethod - def from_request_output(cls, requests): - out = lazy_stack( - [ - cls( - request_id=request.request_id, - prompt=request.prompt, - prompt_token_ids=request.prompt_token_ids, - prompt_logprobs=request.prompt_logprobs, - outputs=request.outputs, - finished=request.finished, - metrics=request.metrics, - lora_request=request.lora_request, - encoder_prompt=request.encoder_prompt, - encoder_prompt_token_ids=request.encoder_prompt_token_ids, - num_cached_tokens=request.num_cached_tokens, - ) - for request in requests - ] - ) - return out From 268dc3666eee4a90ed3c4221b3e8bf5d9aa13d2e Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 23 Apr 2025 07:58:36 +0100 Subject: [PATCH 2/5] Update [ghstack-poisoned] --- torchrl/data/__init__.py | 8 -------- torchrl/envs/__init__.py | 9 +-------- 
torchrl/envs/custom/__init__.py | 4 ++-- torchrl/envs/transforms/__init__.py | 6 ------ 4 files changed, 3 insertions(+), 24 deletions(-) diff --git a/torchrl/data/__init__.py b/torchrl/data/__init__.py index bf2e3a949b0..a1b09b524ef 100644 --- a/torchrl/data/__init__.py +++ b/torchrl/data/__init__.py @@ -8,10 +8,6 @@ ConstantKLController, create_infinite_iterator, get_dataloader, - History, - LLMData, - LLMInput, - LLMOutput, PairwiseDataset, PromptData, PromptTensorDictTokenizer, @@ -109,7 +105,6 @@ __all__ = [ "AdaptiveKLController", - "History", "Binary", "BinaryDiscreteTensorSpec", "BinaryToDecimal", @@ -131,9 +126,6 @@ "H5StorageCheckpointer", "HashToInt", "ImmutableDatasetWriter", - "LLMData", - "LLMInput", - "LLMOutput", "LazyMemmapStorage", "LazyStackStorage", "LazyStackedCompositeSpec", diff --git a/torchrl/envs/__init__.py b/torchrl/envs/__init__.py index 498162d01f5..de6c48c1402 100644 --- a/torchrl/envs/__init__.py +++ b/torchrl/envs/__init__.py @@ -6,7 +6,7 @@ from .async_envs import AsyncEnvPool, ProcessorAsyncEnvPool, ThreadingAsyncEnvPool from .batched_envs import ParallelEnv, SerialEnv from .common import EnvBase, EnvMetaData, make_tensordict -from .custom import ChessEnv, LLMEnv, LLMHashingEnv, PendulumEnv, TicTacToeEnv +from .custom import ChessEnv, LLMHashingEnv, PendulumEnv, TicTacToeEnv from .env_creator import env_creator, EnvCreator, get_env_metadata from .gym_like import default_info_dict_reader, GymLikeEnv from .libs import ( @@ -47,8 +47,6 @@ from .transforms import ( ActionDiscretizer, ActionMask, - as_nested_tensor, - as_padded_tensor, AutoResetEnv, AutoResetTransform, BatchSizeTransform, @@ -61,7 +59,6 @@ Compose, ConditionalSkip, Crop, - DataLoadingPrimer, DeviceCastTransform, DiscreteActionProjection, DoubleToFloat, @@ -155,7 +152,6 @@ "DMControlEnv", "DMControlWrapper", "DTypeCastTransform", - "DataLoadingPrimer", "DeviceCastTransform", "DiscreteActionProjection", "DoubleToFloat", @@ -182,7 +178,6 @@ "JumanjiEnv", "JumanjiWrapper", "KLRewardTransform", - "LLMEnv", "LLMHashingEnv", "LineariseRewards", "MOGymEnv", @@ -247,8 +242,6 @@ "VecNorm", "VmasEnv", "VmasWrapper", - "as_nested_tensor", - "as_padded_tensor", "check_env_specs", "check_marl_grouping", "default_info_dict_reader", diff --git a/torchrl/envs/custom/__init__.py b/torchrl/envs/custom/__init__.py index 24ffee4b3f1..9c98af1644a 100644 --- a/torchrl/envs/custom/__init__.py +++ b/torchrl/envs/custom/__init__.py @@ -4,8 +4,8 @@ # LICENSE file in the root directory of this source tree. 
from .chess import ChessEnv -from .llm import LLMEnv, LLMHashingEnv +from .llm import LLMHashingEnv from .pendulum import PendulumEnv from .tictactoeenv import TicTacToeEnv -__all__ = ["ChessEnv", "LLMHashingEnv", "PendulumEnv", "TicTacToeEnv", "LLMEnv"] +__all__ = ["ChessEnv", "LLMHashingEnv", "PendulumEnv", "TicTacToeEnv"] diff --git a/torchrl/envs/transforms/__init__.py b/torchrl/envs/transforms/__init__.py index bb485512e14..82b0555e1a5 100644 --- a/torchrl/envs/transforms/__init__.py +++ b/torchrl/envs/transforms/__init__.py @@ -5,9 +5,6 @@ from .gym_transforms import EndOfLifeTransform from .llm import ( - as_nested_tensor, - as_padded_tensor, - DataLoadingPrimer, KLRewardTransform, ) from .r3m import R3MTransform @@ -93,7 +90,6 @@ "ConditionalSkip", "Crop", "DTypeCastTransform", - "DataLoadingPrimer", "DeviceCastTransform", "DiscreteActionProjection", "DoubleToFloat", @@ -145,7 +141,5 @@ "VecGymEnvTransform", "VecNorm", "VecNormV2", - "as_nested_tensor", - "as_padded_tensor", "gSDENoise", ] From 0349c86ee4d66e5f7e40cb8e060e0a72393237a8 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 23 Apr 2025 08:08:19 +0100 Subject: [PATCH 3/5] Update [ghstack-poisoned] --- torchrl/envs/transforms/__init__.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/torchrl/envs/transforms/__init__.py b/torchrl/envs/transforms/__init__.py index 82b0555e1a5..930c2cf6ebd 100644 --- a/torchrl/envs/transforms/__init__.py +++ b/torchrl/envs/transforms/__init__.py @@ -4,9 +4,7 @@ # LICENSE file in the root directory of this source tree. from .gym_transforms import EndOfLifeTransform -from .llm import ( - KLRewardTransform, -) +from .llm import KLRewardTransform from .r3m import R3MTransform from .rb_transforms import MultiStepTransform from .transforms import ( From 55780cd3f8d866f94d8ed609fbf27ee06b4670b6 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 23 Apr 2025 10:49:03 +0100 Subject: [PATCH 4/5] Update [ghstack-poisoned] --- test/test_rb.py | 197 ------------------------------------------------ 1 file changed, 197 deletions(-) diff --git a/test/test_rb.py b/test/test_rb.py index 54ff9b80e5c..d4577ae8ca9 100644 --- a/test/test_rb.py +++ b/test/test_rb.py @@ -25,7 +25,6 @@ is_tensor_collection, is_tensorclass, LazyStackedTensorDict, - set_list_to_stack, tensorclass, TensorDict, TensorDictBase, @@ -38,7 +37,6 @@ from torchrl.collectors.utils import split_trajectories from torchrl.data import ( FlatStorageCheckpointer, - History, MultiStep, NestedStorageCheckpointer, PrioritizedReplayBuffer, @@ -3919,201 +3917,6 @@ def test_multi_env(self, storage_type, checkpointer, tmpdir, frames_per_batch): assert rb._writer._cursor == rb_test._writer._cursor -class TestHistory: - @pytest.fixture(scope="class", autouse=True) - def set_context(self): - with set_list_to_stack(True): - yield - - def test_history_construct(self): - hst0 = History(role="user", content="a message") - assert not hst0.shape - hst1 = History(role="user", content="another message") - with pytest.raises(RuntimeError, match="unsqueeze"): - hst0.append(hst1) - hst0 = hst0.unsqueeze(0) - - # In an env.step, we typically have one more piece of history to add to the stack - assert not hst1.shape - assert not hst1.batch_size - assert not hst1.batch_dims - # test out-place - hst0_copy = hst0.copy() - hst0b = hst0.append(hst1, inplace=False) - assert hst0b is not hst0 - assert (hst0 == hst0_copy).all() - assert (hst0b[:-1] == hst0).all() - - # test in-place - hst0b = hst0.append(hst1) - assert hst0b is hst0 - assert 
hst0b.shape == (2,) - - assert hst0b.content == ["a message", "another message"] - hst2 = History( - role=["assistant", "user"], - content=["i'm the assistant", "i'm the user"], - batch_size=2, - ) - assert hst2[0].role == "assistant" - assert hst2[0].content == "i'm the assistant" - assert hst2[1].role == "user" - assert hst2[1].content == "i'm the user" - with pytest.raises(RuntimeError, match="The new history to extend"): - hst0.extend(hst1) - - # test out-place - hst0_copy = hst0.copy() - hst0b = hst0.extend(hst2, inplace=False) - assert hst0b is not hst0 - assert (hst0 == hst0_copy).all() - assert (hst0b[:-2] == hst0).all() - - # test in-place - hst0b = hst0.extend(hst2) - - assert hst0b is hst0 - assert hst0.__dict__["_tensordict"].shape == (4,) - assert hst0.shape == (4,) - assert hst0.role == ["user", "user", "assistant", "user"] - assert hst0.content == [ - "a message", - "another message", - "i'm the assistant", - "i'm the user", - ] - - def test_history_construct_ndim(self): - hst0 = History(role="user", content="a message").unsqueeze(0).unsqueeze(0) - hst1 = History(role="user", content="another message").unsqueeze(0) - - # test out-place - hst0_copy = hst0.copy() - hst0b = hst0.append(hst1, inplace=False, dim=1) - assert hst0b is not hst0 - assert (hst0 == hst0_copy).all() - assert (hst0b[:, :-1] == hst0).all() - - # test in-place - hst0b = hst0.append(hst1, dim=1) - assert hst0b is hst0 - assert hst0b.shape == ( - 1, - 2, - ) - - assert hst0b.content == [["a message", "another message"]] - hst2 = History( - role=["assistant", "user"], - content=["i'm the assistant", "i'm the user"], - batch_size=2, - ).unsqueeze(0) - - # test out-place - hst0_copy = hst0.copy() - hst0b = hst0.extend(hst2, inplace=False, dim=1) - assert hst0b is not hst0 - assert (hst0 == hst0_copy).all() - assert (hst0b[:, :-2] == hst0).all() - - # test in-place - hst0b = hst0.extend(hst2, dim=1) - - assert hst0b is hst0 - assert hst0.__dict__["_tensordict"].shape == ( - 1, - 4, - ) - assert hst0.shape == ( - 1, - 4, - ) - assert hst0.role == [["user", "user", "assistant", "user"]] - assert hst0.content == [ - [ - "a message", - "another message", - "i'm the assistant", - "i'm the user", - ] - ] - - @pytest.fixture(scope="class") - def mock_history(self): - history0 = History( - role="system", - content="""CONTENT - This is the setup""", - ) - history1 = History( - role="user", - content="""CONTENT - This is the first user prompt""", - ) - history2 = History( - role="assistant", - content="""CONTENT - This is the second prompt, the first for the assistant.""", - ) - history = torch.stack([history0, history1, history2]) - return history - - @pytest.fixture(scope="class") - def tokenizer(self): - from transformers import AutoTokenizer - - tokenizer = AutoTokenizer.from_pretrained("GPT2") - yield tokenizer - - @pytest.mark.skipif(not _has_transformers, reason="requires transformers library") - def test_history_template(self, mock_history, tokenizer): - history = mock_history - data_str = history.apply_chat_template( - tokenizer=tokenizer, add_generation_prompt=False - ) - assert isinstance(data_str, str) - data_token = history.apply_chat_template( - tokenizer=tokenizer, tokenize=True, add_generation_prompt=False - ) - assert isinstance(data_token, torch.Tensor) - - # test add_generation_prompt - data_str = history.apply_chat_template( - tokenizer=tokenizer, add_generation_prompt=True - ) - assert isinstance(data_str, str) - assert data_str.endswith("<|im_start|>assistant\n"), data_str - - @pytest.mark.skipif(not 
_has_transformers, reason="requires transformers library") - def test_history_template_recover(self, mock_history, tokenizer): - history = mock_history - data_str = history.apply_chat_template(tokenizer=tokenizer) - # Test inverse - recovered = history._inv_chatml(data_str) - assert recovered.role == history.role - assert recovered.content == history.content - data_token = history.apply_chat_template( - tokenizer=tokenizer, tokenize=True, add_generation_prompt=False - ) - recovered = history._inv_chatml(tokenizer.batch_decode(data_token)[0]) - - def test_history_spec(self): - history = History( - role=["system", "user", "assistant", "user"], - content=[ - "i'm the system", - "i'm the user", - "I'm the assistant", - "I'm the user again", - ], - ) - spec = history.default_spec() - r = spec.zero() - assert isinstance(r, History) - assert spec.is_in(r) - assert spec.is_in(history) - - if __name__ == "__main__": args, unknown = argparse.ArgumentParser().parse_known_args() pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown) From 79e47af6b4998da665c2afedaa448ed29d3bc2ba Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 23 Apr 2025 11:03:55 +0100 Subject: [PATCH 5/5] Update [ghstack-poisoned] --- .github/unittest/linux_libs/scripts_llm/install.sh | 0 .github/unittest/linux_libs/scripts_llm/post_process.sh | 0 .github/unittest/linux_libs/scripts_llm/run-clang-format.py | 0 .github/unittest/linux_libs/scripts_llm/run_test.sh | 2 -- .github/unittest/linux_libs/scripts_llm/setup_env.sh | 0 5 files changed, 2 deletions(-) mode change 100755 => 100644 .github/unittest/linux_libs/scripts_llm/install.sh mode change 100755 => 100644 .github/unittest/linux_libs/scripts_llm/post_process.sh mode change 100755 => 100644 .github/unittest/linux_libs/scripts_llm/run-clang-format.py mode change 100755 => 100644 .github/unittest/linux_libs/scripts_llm/run_test.sh mode change 100755 => 100644 .github/unittest/linux_libs/scripts_llm/setup_env.sh diff --git a/.github/unittest/linux_libs/scripts_llm/install.sh b/.github/unittest/linux_libs/scripts_llm/install.sh old mode 100755 new mode 100644 diff --git a/.github/unittest/linux_libs/scripts_llm/post_process.sh b/.github/unittest/linux_libs/scripts_llm/post_process.sh old mode 100755 new mode 100644 diff --git a/.github/unittest/linux_libs/scripts_llm/run-clang-format.py b/.github/unittest/linux_libs/scripts_llm/run-clang-format.py old mode 100755 new mode 100644 diff --git a/.github/unittest/linux_libs/scripts_llm/run_test.sh b/.github/unittest/linux_libs/scripts_llm/run_test.sh old mode 100755 new mode 100644 index e4a3f2dad28..eab70cb6b4c --- a/.github/unittest/linux_libs/scripts_llm/run_test.sh +++ b/.github/unittest/linux_libs/scripts_llm/run_test.sh @@ -27,8 +27,6 @@ python -c "import transformers, datasets" pytest test/test_rlhf.py --instafail -v --durations 200 --capture no --error-for-skips -pytest test/test_actors.py test/test_collector.py -k llm --instafail -v --durations 200 --capture no --error-for-skips --runslow - pytest examples/rlhf/train_rlhf.py \ sys.device=cuda:0 sys.ref_device=cuda:0 \ model.name_or_path=gpt2 train.max_epochs=2 \ diff --git a/.github/unittest/linux_libs/scripts_llm/setup_env.sh b/.github/unittest/linux_libs/scripts_llm/setup_env.sh old mode 100755 new mode 100644
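
For reference, the core computation behind the deleted `_log_probs_from_logits` / `_log_probs_generate` helpers of the Transformers wrapper is a log-softmax over the causal-LM logits followed by a gather on the response token ids, with a one-position shift. The sketch below only illustrates that pattern; the model name, prompt strings and variable names are illustrative assumptions, and none of it is TorchRL API or part of this patch.

# Minimal sketch, assuming a Hugging Face causal LM (here "gpt2") is available.
# It reproduces the shift-and-gather pattern of the deleted helper, without the
# padding/unpadding and TensorDict bookkeeping the wrapper layered on top.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

prompt_ids = tokenizer("Hello, world!", return_tensors="pt")["input_ids"]
response_ids = tokenizer(" and a short continuation", return_tensors="pt")["input_ids"]
input_ids = torch.cat([prompt_ids, response_ids], dim=-1)

with torch.no_grad():
    logits = model(input_ids).logits  # (1, prompt_len + response_len, vocab_size)

# The logit at position t scores the token at position t + 1, hence the
# one-step shift before gathering the response token ids.
log_softmax = logits.log_softmax(dim=-1)
response_len = response_ids.shape[-1]
shifted = log_softmax[:, -response_len - 1 : -1, :]
log_probs = shifted.gather(-1, response_ids.unsqueeze(-1)).squeeze(-1)
print(log_probs.shape)  # torch.Size([1, response_len])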