
Commit 25797a0

Use default head dropout prob if not provided by model (#685)
Fixes #666 and issue described in #523.
1 parent 43ccd9f commit 25797a0

File tree: 3 files changed (+13, -6 lines)


src/adapters/heads/base.py (+12, -4)

@@ -65,15 +65,23 @@ def __init__(self, name):
         self.config = {}
         self.name = name
 
-    def build(self, model):
-        model_config = model.config
-        pred_head = []
+    def _get_dropout_prob(self, model_config):
+        # try to infer dropout prob from various sources, default to 0.0
         if "dropout_prob" in self.config and self.config["dropout_prob"] is not None:
             dropout_prob = self.config["dropout_prob"]
         elif hasattr(model_config, "classifier_dropout") and model_config.classifier_dropout is not None:
             dropout_prob = model_config.classifier_dropout
-        else:
+        elif hasattr(model_config, "hidden_dropout_prob") and model_config.hidden_dropout_prob is not None:
             dropout_prob = model_config.hidden_dropout_prob
+        else:
+            dropout_prob = 0.0
+
+        return dropout_prob
+
+    def build(self, model):
+        model_config = model.config
+        pred_head = []
+        dropout_prob = self._get_dropout_prob(model_config)
         bias = self.config.get("bias", True)
         for l_id in range(self.config["layers"]):
             pred_head.append(nn.Dropout(dropout_prob))
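
The new helper resolves the head dropout probability by falling back through the head config, then classifier_dropout, then hidden_dropout_prob, and finally a 0.0 default. The standalone sketch below only mirrors that fallback order for illustration; resolve_dropout_prob and the SimpleNamespace configs are hypothetical names, not part of the adapters library:

```python
from types import SimpleNamespace


def resolve_dropout_prob(head_config: dict, model_config) -> float:
    """Sketch of the fallback order introduced in this commit:
    head config -> classifier_dropout -> hidden_dropout_prob -> 0.0 default."""
    if head_config.get("dropout_prob") is not None:
        return head_config["dropout_prob"]
    if getattr(model_config, "classifier_dropout", None) is not None:
        return model_config.classifier_dropout
    if getattr(model_config, "hidden_dropout_prob", None) is not None:
        return model_config.hidden_dropout_prob
    return 0.0


# A Llama-style config defines neither dropout attribute, so the head now
# gets 0.0 instead of raising AttributeError (presumably the failure in #666).
llama_like = SimpleNamespace(hidden_size=32)
assert resolve_dropout_prob({}, llama_like) == 0.0

# A BERT-style config keeps using its hidden_dropout_prob, so existing behavior is unchanged.
bert_like = SimpleNamespace(classifier_dropout=None, hidden_dropout_prob=0.1)
assert resolve_dropout_prob({}, bert_like) == 0.1

# An explicit dropout_prob in the head config always wins.
assert resolve_dropout_prob({"dropout_prob": 0.2}, bert_like) == 0.2
```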

src/adapters/heads/dependency_parsing.py (+1, -1)

@@ -81,7 +81,7 @@ def build(self, model):
             n_in=model.config.hidden_size, n_out=self.config["num_labels"], bias_x=True, bias_y=True
         )
 
-        self.dropout = nn.Dropout(model.config.hidden_dropout_prob)
+        self.dropout = nn.Dropout(self._get_dropout_prob(model.config))
 
         self.loss_fn = CrossEntropyLoss()

tests/test_llama.py (-1)

@@ -29,7 +29,6 @@ class LlamaAdapterTestBase(AdapterTestBase):
         num_attention_heads=4,
         intermediate_size=37,
         hidden_act="gelu",
-        hidden_dropout_prob=0.1,
         pad_token_id=0,
     )
     tokenizer_name = "openlm-research/open_llama_13b"
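
For context on the test change: a stock LlamaConfig does not define hidden_dropout_prob, so passing it explicitly in the test base presumably hid the missing-attribute case this commit fixes; with the kwarg removed, the test config matches a real Llama config and exercises the new 0.0 default. A quick check, assuming a recent transformers release (illustrative only, not part of the test suite):

```python
from transformers import LlamaConfig

# A freshly constructed LlamaConfig has no hidden_dropout_prob attribute, so
# prediction heads built on it now fall back to 0.0 instead of crashing.
config = LlamaConfig(hidden_size=32, num_hidden_layers=2)
print(hasattr(config, "hidden_dropout_prob"))  # expected: False
```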
