22 changes: 22 additions & 0 deletions tests/test_utils.py
```diff
@@ -136,6 +136,28 @@ def test_pad_to_multiple_of_no_extra_padding(self):
         expected = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])
         self.assertTrue(torch.equal(output, expected))
 
+    def test_pad_empty_list(self):
+        # Test that the pad function handles an empty list gracefully
+        output = pad([], padding_value=0, padding_side="right")
+        expected = torch.empty((0,), dtype=torch.int64)
+        self.assertTrue(torch.equal(output, expected))
+        self.assertEqual(output.shape, (0,))
+        self.assertEqual(output.dtype, torch.int64)
+
+    def test_pad_empty_list_with_padding_value(self):
+        # Test that the pad function handles an empty list regardless of the padding value
+        output = pad([], padding_value=42, padding_side="left")
+        expected = torch.empty((0,), dtype=torch.int64)
+        self.assertTrue(torch.equal(output, expected))
+        self.assertEqual(output.shape, (0,))
+
+    def test_pad_empty_list_with_pad_to_multiple_of(self):
+        # Test that the pad function handles an empty list with pad_to_multiple_of
+        output = pad([], padding_value=0, padding_side="right", pad_to_multiple_of=4)
+        expected = torch.empty((0,), dtype=torch.int64)
+        self.assertTrue(torch.equal(output, expected))
+        self.assertEqual(output.shape, (0,))
+
 
 @require_peft
 class TestGetPEFTConfig(TrlTestCase):
```
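For reviewers who want to sanity-check the new behavior outside the test suite, here is a minimal interactive check; it assumes `pad` is imported from `trl.trainer.utils`, as in the test module above:

```python
import torch

from trl.trainer.utils import pad

# An empty input list now returns an empty 1-D int64 tensor instead of raising.
out = pad([], padding_value=0, padding_side="right")
assert out.shape == (0,)
assert out.dtype == torch.int64

# The padding options have no effect on an empty input.
assert torch.equal(pad([], padding_value=42, padding_side="left", pad_to_multiple_of=4), out)
```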
13 changes: 13 additions & 0 deletions trl/trainer/grpo_trainer.py
```diff
@@ -1307,6 +1307,19 @@ def _generate_and_score_completions(
             )
             completion_ids = [output.generated_tokens for output in all_outputs.values()]
             completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
+
+            # Handle the case where no completions are generated
+            if not completion_ids:
+                # Create empty completion tensors with the same batch size as the prompts
+                batch_size = paged_prompt_inputs.input_ids.shape[0]
+                completion_ids = torch.empty((batch_size, 0), dtype=torch.int64, device=device)
+                prompt_ids = [torch.tensor(ids, device=device) for ids in paged_prompt_inputs.input_ids]
+                prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left")
+                prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
+                # Restore the original attention implementation before returning early
+                self.model_wrapped.config._attn_implementation = previous_attn
+                return prompt_completion_ids
+
             completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right")
             prompt_ids = [torch.tensor(ids, device=device) for ids in paged_prompt_inputs.input_ids]
             prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left")
```
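The guard above can be hard to read inline, so here is a standalone sketch of the same zero-width-completions pattern. The names are illustrative, not the trainer's actual attributes; only `pad` comes from `trl.trainer.utils`:

```python
import torch

from trl.trainer.utils import pad

def concat_with_empty_completions(prompt_id_lists, pad_token_id, device="cpu"):
    # Left-pad the prompts into a (batch, prompt_len) tensor.
    prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_id_lists]
    prompt_ids = pad(prompt_ids, padding_value=pad_token_id, padding_side="left")
    # Zero-width completions: concatenating along dim=1 adds no columns.
    batch_size = prompt_ids.shape[0]
    completion_ids = torch.empty((batch_size, 0), dtype=torch.int64, device=device)
    return torch.cat([prompt_ids, completion_ids], dim=1)

# Two prompts of lengths 2 and 3 yield a (2, 3) tensor: exactly the padded prompts.
out = concat_with_empty_completions([[1, 2], [3, 4, 5]], pad_token_id=0)
assert out.shape == (2, 3)
```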
9 changes: 9 additions & 0 deletions trl/trainer/online_dpo_trainer.py
```diff
@@ -1112,6 +1112,15 @@ def _generate(self, model, prompts, images=None):
             )
             completion_ids = [output.generated_tokens for output in all_outputs.values()]
             completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
+
+            # Handle the case where no completions are generated
+            if not completion_ids:
+                # Create empty completion tensors with the same batch size as the prompts
+                batch_size = prompt_ids.size(0)
+                completion_ids = torch.empty((batch_size, 0), dtype=torch.int64, device=device)
+                completion_mask = torch.empty((batch_size, 0), dtype=torch.bool, device=device)
+                return prompt_ids, prompt_mask, completion_ids, completion_mask
+
             completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right")
             prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
             # Restore the original attention implementation, training mode
```
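A quick shape sketch of what the early return above hands back, assuming an already padded `(batch, prompt_len)` prompt tensor; the dtypes here are illustrative, with the completion-mask dtype mirroring the diff:

```python
import torch

batch_size, prompt_len = 4, 7
prompt_ids = torch.zeros((batch_size, prompt_len), dtype=torch.int64)
prompt_mask = torch.ones((batch_size, prompt_len), dtype=torch.bool)

# Zero-width completions and mask, matching the trainer's early return.
completion_ids = torch.empty((batch_size, 0), dtype=torch.int64)
completion_mask = torch.empty((batch_size, 0), dtype=torch.bool)

# Downstream concatenation still works; the completions contribute no columns.
full_ids = torch.cat([prompt_ids, completion_ids], dim=1)
assert full_ids.shape == (batch_size, prompt_len)
```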
13 changes: 13 additions & 0 deletions trl/trainer/rloo_trainer.py
```diff
@@ -1139,6 +1139,19 @@ def _generate_and_score_completions(
             )
             completion_ids = [output.generated_tokens for output in all_outputs.values()]
             completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
+
+            # Handle the case where no completions are generated
+            if not completion_ids:
+                # Create empty completion tensors with the same batch size as the prompts
+                batch_size = paged_prompt_inputs.input_ids.shape[0]
+                completion_ids = torch.empty((batch_size, 0), dtype=torch.int64, device=device)
+                prompt_ids = [torch.tensor(ids, device=device) for ids in paged_prompt_inputs.input_ids]
+                prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left")
+                prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
+                # Restore the original attention implementation before returning early
+                self.model_wrapped.config._attn_implementation = previous_attn
+                return prompt_completion_ids
+
             completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right")
             prompt_ids = [torch.tensor(ids, device=device) for ids in paged_prompt_inputs.input_ids]
             prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left")
```
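The RLOO guard is identical to the GRPO one, so the only subtlety worth spelling out is why the explicit `dtype=torch.int64` matters when building the zero-width tensor; a small sketch:

```python
import torch

prompt_ids = torch.tensor([[0, 1, 2], [3, 4, 5]])  # int64 token ids

# With a matching dtype, the zero-width tensor leaves the result untouched.
empty = torch.empty((2, 0), dtype=torch.int64)
out = torch.cat([prompt_ids, empty], dim=1)
assert out.dtype == torch.int64 and torch.equal(out, prompt_ids)

# Without the explicit dtype, torch.empty defaults to float32; depending on the
# PyTorch version, torch.cat then either raises or promotes the token ids to float.
```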
6 changes: 6 additions & 0 deletions trl/trainer/utils.py
````diff
@@ -284,6 +284,12 @@ def pad(
                  [0, 0]]])
         ```
     """
+    # Handle an empty list of tensors
+    if not tensors:
+        # Return an empty tensor with shape (0,) and an appropriate dtype;
+        # int64 is the default since most use cases involve token IDs
+        return torch.empty((0,), dtype=torch.int64)
+
     # Determine the maximum shape for each dimension
     output_shape = np.max([t.shape for t in tensors], 0).tolist()
````
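Finally, a behavior sketch for the new guard in `pad` itself: the early return avoids the `np.max` call over an empty list of shapes, which would otherwise raise, while the non-empty path is unchanged:

```python
import torch

from trl.trainer.utils import pad

# Empty input: early return with an empty 1-D int64 tensor.
assert pad([], padding_value=0).shape == (0,)

# Non-empty input: behavior is unchanged (right padding by default).
padded = pad([torch.tensor([1, 2]), torch.tensor([3])], padding_value=0)
assert torch.equal(padded, torch.tensor([[1, 2], [3, 0]]))
```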