Commit 7855c9d: Enable parallel sequence generation with adapters (#436)

1 parent: 1bc2da4

18 files changed, +122 -3 lines

src/transformers/adapters/context.py  (+1, -1)

@@ -78,7 +78,7 @@ class ForwardContext:
     # thread-local storage that holds a stack of active contexts
     storage = threading.local()
 
-    context_attributes = ["adapter_gating_scores", "adapter_fusion_attentions"]
+    context_attributes = ["adapter_gating_scores", "adapter_fusion_attentions", "adapter_input_parallelized"]
 
     def __init__(self, model, *args, **kwargs):
         # If the model has a method ``forward_context()``, use it to create the context.

src/transformers/adapters/heads/base.py  (+4)

@@ -31,6 +31,10 @@ class MultiHeadOutput(ModelOutput):
     head_outputs: List[ModelOutput] = None
     loss: Optional[torch.FloatTensor] = None
 
+    @property
+    def logits(self):
+        return torch.vstack([outputs["logits"] for outputs in self.head_outputs])
+
     def __getitem__(self, k):
         # with number indices the head output at that position is accessed
         # e.g output[1] is equivalent to output.head_outputs[1]
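
Note: the new `logits` property concatenates the per-head predictions along the batch dimension so the generation code can treat a multi-head output like a single logits tensor. A minimal shape check (illustrative values only, not library code):

import torch

batch, seq_len, vocab = 2, 4, 10
head_logits = [torch.randn(batch, seq_len, vocab) for _ in range(2)]  # one tensor per parallel head
stacked = torch.vstack(head_logits)                                   # same as torch.cat(..., dim=0) here
print(stacked.shape)  # torch.Size([4, 4, 10]) -> (num_heads * batch, seq_len, vocab)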

src/transformers/adapters/model_mixin.py  (+5)

@@ -779,6 +779,11 @@ def forward_context(self, context: ForwardContext, *args, **kwargs):
             return
 
         context.adapters_parallelized = False
+        # Check if already parallelized in encoder
+        adapter_input_parallelized = kwargs.pop("adapter_input_parallelized", None)
+        if adapter_input_parallelized:
+            if active_adapters.parallel_channels > 1:
+                context.adapters_parallelized = True
         # Add the shared parameters for the active adapters to the context
         context.shared_parameters = {
             name: param
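
Note: this check only marks the forward context as already parallelized; the actual input replication happens once in generate() (see generation_utils.py below). A standalone sketch of the control flow, with stand-ins for the context object and the active setup's channel count (not the library's types):

class _Context:
    adapters_parallelized = False

def mark_parallelized(context, parallel_channels, **kwargs):
    # Mirror of the added logic: if the caller signals that the inputs were
    # already replicated (e.g. by generate()), do not replicate them again.
    context.adapters_parallelized = False
    if kwargs.pop("adapter_input_parallelized", None) and parallel_channels > 1:
        context.adapters_parallelized = True
    return context

ctx = mark_parallelized(_Context(), parallel_channels=2, adapter_input_parallelized=True)
assert ctx.adapters_parallelized is True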

src/transformers/adapters/models/bart/adapter_model.py  (+2)

@@ -89,6 +89,7 @@ def forward(
             past_key_values=past_key_values,
             output_adapter_gating_scores=output_adapter_gating_scores,
             output_adapter_fusion_attentions=output_adapter_fusion_attentions,
+            adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
         )
         # sequence classification based on last token in sequence
         x = outputs[0]  # last hidden state
@@ -139,6 +140,7 @@ def prepare_inputs_for_generation(
             "decoder_head_mask": decoder_head_mask,
             "cross_attn_head_mask": cross_attn_head_mask,
             "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
+            "adapter_input_parallelized": kwargs.pop("adapter_input_parallelized", False),
         }
 
     # Copied from BartForConditionalGeneration
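
Note: with the flag threaded through forward() and prepare_inputs_for_generation(), seq2seq adapter models can run Parallel generation end to end. A hedged usage sketch (checkpoint and adapter names are illustrative; assumes the adapter-transformers package providing AutoAdapterModel):

from transformers import AutoTokenizer
from transformers.adapters import AutoAdapterModel
from transformers.adapters.composition import Parallel

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
model = AutoAdapterModel.from_pretrained("facebook/bart-base")
for name in ("task_a", "task_b"):
    model.add_adapter(name)
    model.add_seq2seq_lm_head(name)
model.set_active_adapters(Parallel("task_a", "task_b"))

batch = tokenizer("Parallel adapters share a single forward pass.", return_tensors="pt")
# One generate() call now returns one sequence per parallel adapter channel.
generated = model.generate(**batch, max_length=32)
print(generated.shape[0])  # 2 (channels) * 1 (batch size)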

src/transformers/adapters/models/beit/adapter_model.py  (+1)

@@ -48,6 +48,7 @@ def forward(
             return_dict=return_dict,
             output_adapter_gating_scores=output_adapter_gating_scores,
             output_adapter_fusion_attentions=output_adapter_fusion_attentions,
+            adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
         )
 
         # BERT & RoBERTa return the pooled output as second item, we don't need that in these heads

src/transformers/adapters/models/bert/adapter_model.py  (+19)

@@ -72,6 +72,7 @@ def forward(
             return_dict=return_dict,
             output_adapter_gating_scores=output_adapter_gating_scores,
             output_adapter_fusion_attentions=output_adapter_fusion_attentions,
+            adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
         )
         # BERT & RoBERTa return the pooled output as second item, we don't need that in these heads
         if not return_dict:
@@ -94,6 +95,24 @@ def forward(
         # in case no head is used just return the output of the base model (including pooler output)
         return outputs
 
+    # Copied from BertLMHeadModel
+    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
+        input_shape = input_ids.shape
+        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
+        if attention_mask is None:
+            attention_mask = input_ids.new_ones(input_shape)
+
+        # cut decoder_input_ids if past is used
+        if past is not None:
+            input_ids = input_ids[:, -1:]
+
+        return {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "past_key_values": past,
+            "adapter_input_parallelized": model_kwargs.pop("adapter_input_parallelized", False),
+        }
+
     head_types = {
         "classification": ClassificationHead,
         "multilabel_classification": MultiLabelClassificationHead,

src/transformers/adapters/models/deberta/adapter_model.py  (+1)

@@ -65,6 +65,7 @@ def forward(
             return_dict=return_dict,
             output_adapter_gating_scores=output_adapter_gating_scores,
             output_adapter_fusion_attentions=output_adapter_fusion_attentions,
+            adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
         )
         # BERT & RoBERTa return the pooled output as second item, we don't need that in these heads
         if not return_dict:

src/transformers/adapters/models/debertaV2/adapter_model.py  (+1)

@@ -68,6 +68,7 @@ def forward(
             return_dict=return_dict,
             output_adapter_gating_scores=output_adapter_gating_scores,
             output_adapter_fusion_attentions=output_adapter_fusion_attentions,
+            adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
         )
         # BERT & RoBERTa return the pooled output as second item, we don't need that in these heads
         if not return_dict:

src/transformers/adapters/models/distilbert/adapter_model.py  (+19)

@@ -94,6 +94,7 @@ def forward(
             return_dict=return_dict,
             output_adapter_gating_scores=output_adapter_gating_scores,
             output_adapter_fusion_attentions=output_adapter_fusion_attentions,
+            adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
         )
 
         outputs = self.forward_head(
@@ -102,6 +103,24 @@ def forward(
 
         return outputs
 
+    # Copied from RobertaForCausalLM
+    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
+        input_shape = input_ids.shape
+        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
+        if attention_mask is None:
+            attention_mask = input_ids.new_ones(input_shape)
+
+        # cut decoder_input_ids if past is used
+        if past is not None:
+            input_ids = input_ids[:, -1:]
+
+        return {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "past_key_values": past,
+            "adapter_input_parallelized": model_kwargs.pop("adapter_input_parallelized", False),
+        }
+
     head_types = {
         "classification": ClassificationHead,
         "multilabel_classification": MultiLabelClassificationHead,

src/transformers/adapters/models/gpt2/adapter_model.py  (+2)

@@ -81,6 +81,7 @@ def forward(
             return_dict=return_dict,
             output_adapter_gating_scores=output_adapter_gating_scores,
             output_adapter_fusion_attentions=output_adapter_fusion_attentions,
+            adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
         )
 
         batch_size = outputs[0].shape[0]
@@ -139,6 +140,7 @@ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
             "position_ids": position_ids,
             "attention_mask": attention_mask,
             "token_type_ids": token_type_ids,
+            "adapter_input_parallelized": kwargs.pop("adapter_input_parallelized", False),
         }
 
     head_types = {

src/transformers/adapters/models/gptj/adapter_model.py  (+2)

@@ -77,6 +77,7 @@ def forward(
             return_dict=return_dict,
             output_adapter_gating_scores=output_adapter_gating_scores,
             output_adapter_fusion_attentions=output_adapter_fusion_attentions,
+            adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
         )
 
         batch_size = outputs[0].shape[0]
@@ -135,6 +136,7 @@ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
             "position_ids": position_ids,
             "attention_mask": attention_mask,
             "token_type_ids": token_type_ids,
+            "adapter_input_parallelized": kwargs.pop("adapter_input_parallelized", False),
         }
 
     head_types = {

src/transformers/adapters/models/mbart/adapter_model.py  (+2)

@@ -89,6 +89,7 @@ def forward(
             past_key_values=past_key_values,
             output_adapter_gating_scores=output_adapter_gating_scores,
             output_adapter_fusion_attentions=output_adapter_fusion_attentions,
+            adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
         )
         # sequence classification based on last token in sequence
         x = outputs[0]  # last hidden state
@@ -139,6 +140,7 @@ def prepare_inputs_for_generation(
             "decoder_head_mask": decoder_head_mask,
             "cross_attn_head_mask": cross_attn_head_mask,
             "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
+            "adapter_input_parallelized": kwargs.pop("adapter_input_parallelized", False),
         }
 
     # Copied from MBartForConditionalGeneration

src/transformers/adapters/models/roberta/adapter_model.py  (+19)

@@ -77,6 +77,7 @@ def forward(
             return_dict=return_dict,
             output_adapter_gating_scores=output_adapter_gating_scores,
             output_adapter_fusion_attentions=output_adapter_fusion_attentions,
+            adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
         )
         # BERT & RoBERTa return the pooled output as second item, we don't need that in these heads
         if not return_dict:
@@ -99,6 +100,24 @@ def forward(
         # in case no head is used just return the output of the base model (including pooler output)
         return outputs
 
+    # Copied from RobertaForCausalLM
+    def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
+        input_shape = input_ids.shape
+        # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
+        if attention_mask is None:
+            attention_mask = input_ids.new_ones(input_shape)
+
+        # cut decoder_input_ids if past is used
+        if past is not None:
+            input_ids = input_ids[:, -1:]
+
+        return {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "past_key_values": past,
+            "adapter_input_parallelized": model_kwargs.pop("adapter_input_parallelized", False),
+        }
+
     head_types = {
         "classification": ClassificationHead,
         "multilabel_classification": MultiLabelClassificationHead,

src/transformers/adapters/models/t5/adapter_model.py  (+2, -1)

@@ -80,6 +80,7 @@ def forward(
             return_dict=return_dict,
             output_adapter_gating_scores=output_adapter_gating_scores,
             output_adapter_fusion_attentions=output_adapter_fusion_attentions,
+            adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
         )
         sequence_output = model_output[0]
         # ToDo move head to device for parallel forward pass
@@ -118,7 +119,6 @@ def prepare_inputs_for_generation(
         encoder_outputs=None,
         **kwargs
     ):
-
         # cut decoder_input_ids if past is used
         if past is not None:
             input_ids = input_ids[:, -1:]
@@ -132,6 +132,7 @@ def prepare_inputs_for_generation(
             "decoder_head_mask": decoder_head_mask,
             "cross_attn_head_mask": cross_attn_head_mask,
             "use_cache": use_cache,
+            "adapter_input_parallelized": kwargs.pop("adapter_input_parallelized", False),
         }
 
     # Copied from T5ForConditionalGeneration

src/transformers/adapters/models/vit/adapter_model.py  (+1)

@@ -48,6 +48,7 @@ def forward(
             return_dict=return_dict,
             output_adapter_gating_scores=output_adapter_gating_scores,
             output_adapter_fusion_attentions=output_adapter_fusion_attentions,
+            adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
         )
 
         # BERT & RoBERTa return the pooled output as second item, we don't need that in these heads

src/transformers/dependency_versions_table.py  (+1, -1)

@@ -22,7 +22,7 @@
     "fugashi": "fugashi>=1.0",
     "GitPython": "GitPython<3.1.19",
     "hf-doc-builder": "hf-doc-builder>=0.3.0",
-    "huggingface-hub": "huggingface-hub>=0.1.0,<0.8.0",
+    "huggingface-hub": "huggingface-hub>=0.1.0,<1.0",
     "importlib_metadata": "importlib_metadata",
     "ipadic": "ipadic>=1.0.0,<2.0",
     "isort": "isort>=5.5.4",

src/transformers/generation_utils.py  (+12)

@@ -23,6 +23,7 @@
 import torch.distributed as dist
 from torch import nn
 
+from .adapters.composition import adjust_tensors_for_parallel
 from .adapters.context import ForwardContext
 from .generation_beam_constraints import Constraint, DisjunctiveConstraint, PhrasalConstraint
 from .generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
@@ -1198,6 +1199,17 @@ def generate(
             # if decoder-only then inputs_tensor has to be `input_ids`
             input_ids = inputs_tensor
 
+        # Pre-replicate inputs for parallel adapters to avoid issues within generation code
+        if (
+            hasattr(self.config, "adapters")
+            and self.config.adapters.active_setup
+            and self.config.adapters.active_setup.parallel_channels > 1
+        ):
+            input_ids = input_ids.repeat(self.config.adapters.active_setup.parallel_channels, 1)
+            model_kwargs["adapter_input_parallelized"] = True
+            (attention_mask,) = adjust_tensors_for_parallel(input_ids, model_kwargs["attention_mask"])
+            model_kwargs["attention_mask"] = attention_mask
+
         # 5. Prepare `max_length` depending on other stopping criteria.
         input_ids_seq_length = input_ids.shape[-1]
         if max_length is None and max_new_tokens is None:
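
Note: this pre-replication is what makes a plain generate() call work with an active Parallel setup, for decoder-only models as well. A hedged end-to-end sketch (checkpoint and adapter names are illustrative; assumes the adapter-transformers package providing AutoAdapterModel):

from transformers import AutoTokenizer
from transformers.adapters import AutoAdapterModel
from transformers.adapters.composition import Parallel

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoAdapterModel.from_pretrained("gpt2")
for name in ("adapter1", "adapter2"):
    model.add_adapter(name)
    model.add_causal_lm_head(name)
model.set_active_adapters(Parallel("adapter1", "adapter2"))

input_ids = tokenizer("Adapters can", return_tensors="pt").input_ids
# input_ids has batch size 1; generate() repeats it once per channel and sets
# adapter_input_parallelized=True before the decoding loop starts.
generated = model.generate(input_ids, max_length=20)
print(generated.shape)  # (2, <=20): one sequence per adapter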

tests_adapters/test_adapter_composition.py  (+28)

@@ -15,6 +15,7 @@
     Trainer,
     TrainingArguments,
 )
+from transformers.adapters import ADAPTER_MODEL_MAPPING
 from transformers.adapters.composition import BatchSplit, Fuse, Parallel, Split, Stack, parse_composition
 from transformers.testing_utils import require_torch, torch_device
 
@@ -245,6 +246,33 @@ def test_batch_split_with_heads(self):
             )
         )
 
+    def test_parallel_generate(self):
+        if self.config_class not in ADAPTER_MODEL_MAPPING or (
+            not hasattr(ADAPTER_MODEL_MAPPING[self.config_class], "add_seq2seq_lm_head")
+            and not hasattr(ADAPTER_MODEL_MAPPING[self.config_class], "add_causal_lm_head")
+        ):
+            self.skipTest("No seq2seq or causal language model head")
+
+        model1 = AutoAdapterModel.from_config(self.config())
+        model1.add_adapter("adapter1")
+        model1.add_adapter("adapter2")
+        if hasattr(model1, "add_seq2seq_lm_head"):
+            model1.add_seq2seq_lm_head("adapter1")
+            model1.add_seq2seq_lm_head("adapter2")
+        else:
+            model1.add_causal_lm_head("adapter1")
+            model1.add_causal_lm_head("adapter2")
+        model1.set_active_adapters(Parallel("adapter1", "adapter2"))
+        model1.to(torch_device)
+
+        seq_output_length = 32
+
+        # Finally, also check if generation works properly
+        input_ids = self.get_input_samples((1, 4), config=model1.config)["input_ids"]
+        input_ids = input_ids.to(torch_device)
+        generated = model1.generate(input_ids, max_length=seq_output_length)
+        self.assertLessEqual(generated.shape, (2, seq_output_length))
+
 
 class ParallelTrainingMixin:
     def create_twin_adapters(self, model, name):
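
Note on reading the result: with Parallel("adapter1", "adapter2") the generated tensor holds the sequences for both channels stacked along the batch dimension, in the order in which the inputs were repeated. A small sketch for splitting them back apart (the channel-per-chunk ordering is an assumption derived from the input_ids.repeat() call above):

import torch

def split_by_adapter(generated: torch.Tensor, num_channels: int = 2):
    # generated.shape == (num_channels * batch_size, seq_len)
    return torch.chunk(generated, num_channels, dim=0)

# seqs_adapter1, seqs_adapter2 = split_by_adapter(generated)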
