Skip to content

Commit ec03371

Browse files
authored
update ultravox model and config for v0.5 (#276)
1 parent 5c4c45e commit ec03371

File tree

3 files changed

+57
-26
lines changed

3 files changed

+57
-26
lines changed

ultravox/model/ultravox_config.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ class LoraConfigSimplified:
1919
target_modules: Optional[List[str]] = dataclasses.field(
2020
default_factory=lambda: ["k_proj", "q_proj", "linear_k", "linear_q"]
2121
)
22+
# A list of module names regex patterns to unfreeze. Only used if r == 0.
23+
unfreeze_layers: Optional[List[str]] = None
2224

2325

2426
class LossFunction(str, Enum):
@@ -28,7 +30,7 @@ class LossFunction(str, Enum):
2830

2931
@dataclasses.dataclass
3032
class LossConfig:
31-
loss_function: LossFunction = LossFunction.KL_Divergence
33+
loss_function: LossFunction = LossFunction.CrossEntropy
3234
kl_temperature: float = 2.0
3335

3436
@property
@@ -70,7 +72,7 @@ class UltravoxConfig(transformers.PretrainedConfig):
7072
Example:
7173
7274
```python
73-
>>> from transformers import UltravoxForConditionalGeneration, Wav2Vec2Config, UltravoxConfig, LlamaConfig
75+
>>> from transformers import UltravoxModel, Wav2Vec2Config, UltravoxConfig, LlamaConfig
7476
7577
>>> # Initializing an audio encoder config
7678
>>> audio_config = Wav2Vec2Config()
@@ -82,7 +84,7 @@ class UltravoxConfig(transformers.PretrainedConfig):
8284
>>> configuration = UltravoxConfig(audio_config, text_config)
8385
8486
>>> # Initializing a completely untrained model from the configuration
85-
>>> model = UltravoxForConditionalGeneration(configuration)
87+
>>> model = UltravoxModel(configuration)
8688
8789
>>> # Accessing the model configuration
8890
>>> configuration = model.config
@@ -105,6 +107,7 @@ def __init__(
105107
stack_factor: int = 8,
106108
norm_init: float = 0.4,
107109
projector_act: str = "swiglu",
110+
projector_ln_mid: bool = False, # defaults to False for compatibility with v0.4.1 and below
108111
text_model_lora_config: Optional[LoraConfigSimplified] = None,
109112
audio_model_lora_config: Optional[LoraConfigSimplified] = None,
110113
audio_latency_block_size: Optional[int] = None,
@@ -119,7 +122,7 @@ def __init__(
119122
self.stack_factor = stack_factor
120123
self.norm_init = norm_init
121124
self.projector_act = projector_act
122-
125+
self.projector_ln_mid = projector_ln_mid
123126
if text_model_id is not None:
124127
self.text_config: transformers.LlamaConfig = (
125128
transformers.AutoConfig.from_pretrained(text_model_id)

ultravox/model/ultravox_config_test.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,8 @@ def test_can_load_release(model_id: str):
1515
config_from_dict = ultravox_config.UltravoxConfig(**orig_config.to_dict())
1616
config_from_diff_dict = ultravox_config.UltravoxConfig(**orig_config.to_diff_dict())
1717
# To not inadvertently ignore other keys, we explicitly define keys we require to ignore.
18-
keys_to_ignore = ("audio_latency_block_size",)
19-
orig_values = {
20-
**{k: None for k in keys_to_ignore},
21-
**orig_config.to_dict(),
22-
}
18+
new_keys_default = {"audio_latency_block_size": None, "projector_ln_mid": False}
19+
orig_values = {**new_keys_default, **orig_config.to_dict()}
2320

2421
assert config_from_dict.to_dict() == orig_values
2522
assert config_from_diff_dict.to_dict() == orig_values

ultravox/model/ultravox_model.py

+48-17
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
import re
23
from typing import Any, Dict, Optional, Set, Tuple, Union
34

45
import peft
@@ -36,6 +37,9 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):
3637
config: UltravoxConfig # for type hinting
3738
# Usually we load encoder and LLM weights from a pretrained model separately, so they are allowed to be missing
3839
_keys_to_ignore_on_load_missing = ["audio_tower.*", "language_model.*"]
40+
# Since we have kwargs in forward, we need to set this to False, otherwise grad_accum_steps will cause incorrect train loss to be reported
41+
# see https://github.com/huggingface/transformers/issues/35856 and https://github.com/huggingface/trl/pull/2615/files
42+
accepts_loss_kwargs = False
3943

4044
def __init__(self, config: UltravoxConfig):
4145
super().__init__(config)
@@ -283,7 +287,7 @@ def _create_audio_tower(
283287
cls, config: UltravoxConfig
284288
) -> Union[transformers.Wav2Vec2Model, "ModifiedWhisperEncoder"]:
285289
if config.audio_model_id is not None:
286-
if "whisper" in config.audio_model_id is not None:
290+
if "whisper" in config.audio_model_id.lower():
287291
audio_tower = ModifiedWhisperEncoder.from_pretrained(
288292
config.audio_model_id, torch_dtype=config.torch_dtype
289293
)
@@ -299,7 +303,7 @@ def _create_audio_tower(
299303
config.audio_model_id, torch_dtype=config.torch_dtype
300304
)
301305
else:
302-
if "whisper" in config.audio_config._name_or_path:
306+
if "whisper" in config.audio_config._name_or_path.lower():
303307
audio_tower = ModifiedWhisperEncoder(config.audio_config)
304308
audio_tower.init_latency_mask(
305309
config.audio_latency_block_size, dtype=config.torch_dtype
@@ -384,12 +388,11 @@ def merge_and_unload(self):
384388

385389
def push_to_hub(self, *args, **kwargs):
386390
self.merge_and_unload()
387-
self.to(self.language_model.dtype)
388391
return super().push_to_hub(*args, **kwargs)
389392

390-
def save_pretrained(
391-
self, *args, state_dict: Optional[Dict[str, Any]] = None, **kwargs
392-
):
393+
def diff_state_dict(
394+
self, state_dict: Optional[Dict[str, Any]] = None
395+
) -> Dict[str, Any]:
393396
if state_dict is None:
394397
state_dict = super().state_dict()
395398

@@ -402,6 +405,13 @@ def save_pretrained(
402405
or (k in named_params and named_params[k].requires_grad)
403406
}
404407

408+
return state_dict
409+
410+
def save_pretrained(
411+
self, *args, state_dict: Optional[Dict[str, Any]] = None, **kwargs
412+
):
413+
state_dict = self.diff_state_dict(state_dict)
414+
405415
super().save_pretrained(*args, state_dict=state_dict, **kwargs)
406416

407417
def _pre_load_state_dict_hook(self, state_dict: Dict[str, Any], *args, **kwargs):
@@ -436,6 +446,7 @@ def print_trainable_parameters(self):
436446
)
437447

438448

449+
# TODO: refactor common parts to a shared module
439450
def is_cache_empty(
440451
past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]]
441452
) -> bool:
@@ -453,12 +464,18 @@ def apply_lora(model: torch.nn.Module, lora_config: dict) -> torch.nn.Module:
453464
"""
454465
Applies LoRA finetuning to the model. If the `r` parameter is set to 0, the model is frozen instead.
455466
"""
467+
unfreeze_layers = lora_config.pop("unfreeze_layers", None)
456468
lora_config = peft.LoraConfig(**lora_config or {})
457469

458470
if lora_config.r == 0:
459-
# freeze the model entirely
460-
for param in model.parameters():
461-
param.requires_grad = False
471+
# freeze the model entirely, except for the specified layers
472+
for name, param in model.named_parameters():
473+
if not unfreeze_layers or not any(
474+
re.match(layer, name) for layer in unfreeze_layers
475+
):
476+
param.requires_grad = False
477+
else:
478+
logging.info(f"Unfreezing layer: {name} with #{param.numel()} params")
462479
else:
463480
model = peft.get_peft_model(model, lora_config)
464481

@@ -502,25 +519,35 @@ def forward(self, x):
502519
return F.silu(gate) * x
503520

504521

505-
class UltravoxProjector(nn.Sequential):
522+
class UltravoxProjector(nn.Module):
506523
def __init__(self, config: UltravoxConfig):
507524
super().__init__()
508525
self.hidden_dim = config.hidden_size
509526
self._pad_and_stack = StackAudioFrames(config.stack_factor)
510-
dim = config.audio_config.hidden_size * config.stack_factor
511-
self.ln_pre = RMSNorm(dim, init=config.norm_init)
512-
self.linear_1 = nn.Linear(dim, self.hidden_dim, bias=False)
513-
dim = self.hidden_dim
527+
dim_in = config.audio_config.hidden_size * config.stack_factor
528+
self.ln_pre = RMSNorm(dim_in, init=config.norm_init)
529+
self.linear_1 = nn.Linear(dim_in, self.hidden_dim, bias=False)
530+
dim_mid = self.hidden_dim
514531
self.act = transformers.activations.get_activation(config.projector_act)
515-
dim = dim // 2 if config.projector_act == "swiglu" else dim
516-
self.linear_2 = nn.Linear(dim, config.text_config.hidden_size, bias=False)
517-
self.ln_post = RMSNorm(config.text_config.hidden_size, init=config.norm_init)
532+
dim_mid = dim_mid // 2 if config.projector_act == "swiglu" else dim_mid
533+
dim_out = config.text_config.hidden_size
534+
self.linear_2 = nn.Linear(dim_mid, dim_out, bias=False)
535+
536+
# Ultravox v0.4.1 and below uses layer_norm after the second linear layer,
537+
# while v0.5.0 and above uses layer_norm after the first linear layer.
538+
if config.projector_ln_mid:
539+
self.ln_mid: nn.Module = RMSNorm(dim_mid, init=config.norm_init)
540+
self.ln_post: nn.Module = nn.Identity()
541+
else:
542+
self.ln_mid = nn.Identity()
543+
self.ln_post = RMSNorm(dim_out, init=config.norm_init)
518544

519545
def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
520546
audio_features = self._pad_and_stack(audio_features)
521547
audio_features = self.ln_pre(audio_features)
522548
hidden_states = self.linear_1(audio_features)
523549
hidden_states = self.act(hidden_states)
550+
hidden_states = self.ln_mid(hidden_states)
524551
hidden_states = self.linear_2(hidden_states)
525552
hidden_states = self.ln_post(hidden_states)
526553
return hidden_states
@@ -544,6 +571,10 @@ class ModifiedWhisperEncoder(
544571
base_model_prefix = "model.encoder"
545572
_no_split_modules = ["WhisperEncoderLayer"]
546573

574+
def __init__(self, config: transformers.WhisperConfig):
575+
super().__init__(config)
576+
self.config.is_decoder = False
577+
547578
def init_latency_mask(self, audio_latency_block_size: int, dtype: torch.dtype):
548579
if audio_latency_block_size is None:
549580
self.audio_streaming_mask = None

0 commit comments

Comments
 (0)