@@ -1,4 +1,5 @@
 import logging
+import re
 from typing import Any, Dict, Optional, Set, Tuple, Union
 
 import peft
@@ -36,6 +37,9 @@ class UltravoxModel(transformers.LlamaPreTrainedModel):
     config: UltravoxConfig  # for type hinting
     # Usually we load encoder and LLM weights from a pretrained model separately, so they are allowed to be missing
     _keys_to_ignore_on_load_missing = ["audio_tower.*", "language_model.*"]
+    # Since forward takes **kwargs, this must be False; otherwise grad_accum_steps causes an incorrect train loss to be reported.
+    # See https://github.com/huggingface/transformers/issues/35856 and https://github.com/huggingface/trl/pull/2615/files
+    accepts_loss_kwargs = False
 
     def __init__(self, config: UltravoxConfig):
         super().__init__(config)
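For context on why `accepts_loss_kwargs` matters: as the linked issues describe, when a model's `forward` accepts `**kwargs`, the HF `Trainer` assumes the model will consume `num_items_in_batch` and normalize a summed loss itself, and skips its own normalization; a model that ignores the kwarg then reports a mis-scaled train loss under gradient accumulation. A toy illustration of the underlying normalization pitfall (plain Python, not `transformers` internals):

```python
# Two gradient-accumulation micro-batches with different token counts.
loss_sums = [12.0, 3.0]  # summed cross-entropy per micro-batch
token_counts = [6, 2]    # tokens contributing to each loss

# Correct: normalize the total loss by the total token count.
true_mean = sum(loss_sums) / sum(token_counts)  # 15 / 8 = 1.875

# Wrong: average the per-micro-batch means.
naive_mean = sum(s / n for s, n in zip(loss_sums, token_counts)) / 2  # 1.75

print(true_mean, naive_mean)  # 1.875 vs. 1.75
```

Opting out with `accepts_loss_kwargs = False` keeps the `Trainer` on its standard per-step normalization.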
@@ -283,7 +287,7 @@ def _create_audio_tower(
         cls, config: UltravoxConfig
     ) -> Union[transformers.Wav2Vec2Model, "ModifiedWhisperEncoder"]:
         if config.audio_model_id is not None:
-            if "whisper" in config.audio_model_id is not None:
+            if "whisper" in config.audio_model_id.lower():
                 audio_tower = ModifiedWhisperEncoder.from_pretrained(
                     config.audio_model_id, torch_dtype=config.torch_dtype
                 )
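The removed condition was an operator-chaining bug: `in` and `is not` are both comparison operators in Python, so they chain. A quick standalone demonstration:

```python
audio_model_id = "openai/whisper-small"

# `a in b is not None` parses as `(a in b) and (b is not None)`.
assert ("whisper" in audio_model_id is not None) == (
    ("whisper" in audio_model_id) and (audio_model_id is not None)
)

# The enclosing branch already guarantees audio_model_id is not None, so the
# old test degenerated to a plain substring check and missed capitalized ids,
# which the new .lower() call handles:
assert "whisper" not in "openai/Whisper-small"
assert "whisper" in "openai/Whisper-small".lower()
```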
@@ -299,7 +303,7 @@ def _create_audio_tower(
                 config.audio_model_id, torch_dtype=config.torch_dtype
             )
         else:
-            if "whisper" in config.audio_config._name_or_path:
+            if "whisper" in config.audio_config._name_or_path.lower():
                 audio_tower = ModifiedWhisperEncoder(config.audio_config)
                 audio_tower.init_latency_mask(
                     config.audio_latency_block_size, dtype=config.torch_dtype
@@ -384,12 +388,11 @@ def merge_and_unload(self):
 
     def push_to_hub(self, *args, **kwargs):
         self.merge_and_unload()
-        self.to(self.language_model.dtype)
         return super().push_to_hub(*args, **kwargs)
 
-    def save_pretrained(
-        self, *args, state_dict: Optional[Dict[str, Any]] = None, **kwargs
-    ):
+    def diff_state_dict(
+        self, state_dict: Optional[Dict[str, Any]] = None
+    ) -> Dict[str, Any]:
         if state_dict is None:
             state_dict = super().state_dict()
 
@@ -402,6 +405,13 @@ def save_pretrained(
             or (k in named_params and named_params[k].requires_grad)
         }
 
+        return state_dict
+
+    def save_pretrained(
+        self, *args, state_dict: Optional[Dict[str, Any]] = None, **kwargs
+    ):
+        state_dict = self.diff_state_dict(state_dict)
+
         super().save_pretrained(*args, state_dict=state_dict, **kwargs)
 
     def _pre_load_state_dict_hook(self, state_dict: Dict[str, Any], *args, **kwargs):
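A minimal sketch of what the extracted `diff_state_dict` enables on its own, assuming an already-constructed `model` (names below are illustrative, not from this PR):

```python
# Inspect which tensors would be serialized without writing a checkpoint.
# Per the comprehension above, only explicitly kept keys and parameters with
# requires_grad=True survive, so checkpoints stay small.
trainable = model.diff_state_dict()
total = sum(t.numel() for t in trainable.values())
print(f"{len(trainable)} tensors / {total:,} params would be saved")

# save_pretrained now routes through the same filter:
model.save_pretrained("checkpoint/")  # hypothetical output directory
```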
@@ -436,6 +446,7 @@ def print_trainable_parameters(self):
         )
 
 
+# TODO: refactor common parts to a shared module
 def is_cache_empty(
     past_key_values: Optional[Union[Tuple, transformers.cache_utils.Cache]]
 ) -> bool:
@@ -453,12 +464,18 @@ def apply_lora(model: torch.nn.Module, lora_config: dict) -> torch.nn.Module:
     """
     Applies LoRA finetuning to the model. If the `r` parameter is set to 0, the model is frozen instead.
     """
+    unfreeze_layers = lora_config.pop("unfreeze_layers", None)
     lora_config = peft.LoraConfig(**lora_config or {})
 
     if lora_config.r == 0:
-        # freeze the model entirely
-        for param in model.parameters():
-            param.requires_grad = False
+        # freeze the model entirely, except for the specified layers
+        for name, param in model.named_parameters():
+            if not unfreeze_layers or not any(
+                re.match(layer, name) for layer in unfreeze_layers
+            ):
+                param.requires_grad = False
+            else:
+                logging.info(f"Unfreezing layer: {name} with #{param.numel()} params")
     else:
         model = peft.get_peft_model(model, lora_config)
 
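A hypothetical config exercising the new `unfreeze_layers` escape hatch; the regex patterns are illustrative, not defaults from this PR. Note that `re.match` anchors at the start of the parameter name:

```python
# r=0 freezes the whole model except parameters whose names match one of the
# unfreeze_layers regexes (each matched with re.match, i.e. from the start).
lora_config = {
    "r": 0,
    "unfreeze_layers": [
        r"multi_modal_projector\.",    # hypothetical: the audio projector
        r"audio_tower\.layers\.31\.",  # hypothetical: one encoder layer
    ],
}
model = apply_lora(model, lora_config)
```

Since `unfreeze_layers` is popped before `peft.LoraConfig(**lora_config)` is built, peft never sees it, and it only takes effect in the `r == 0` branch.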
@@ -502,25 +519,35 @@ def forward(self, x):
         return F.silu(gate) * x
 
 
-class UltravoxProjector(nn.Sequential):
+class UltravoxProjector(nn.Module):
     def __init__(self, config: UltravoxConfig):
         super().__init__()
         self.hidden_dim = config.hidden_size
         self._pad_and_stack = StackAudioFrames(config.stack_factor)
-        dim = config.audio_config.hidden_size * config.stack_factor
-        self.ln_pre = RMSNorm(dim, init=config.norm_init)
-        self.linear_1 = nn.Linear(dim, self.hidden_dim, bias=False)
-        dim = self.hidden_dim
+        dim_in = config.audio_config.hidden_size * config.stack_factor
+        self.ln_pre = RMSNorm(dim_in, init=config.norm_init)
+        self.linear_1 = nn.Linear(dim_in, self.hidden_dim, bias=False)
+        dim_mid = self.hidden_dim
         self.act = transformers.activations.get_activation(config.projector_act)
-        dim = dim // 2 if config.projector_act == "swiglu" else dim
-        self.linear_2 = nn.Linear(dim, config.text_config.hidden_size, bias=False)
-        self.ln_post = RMSNorm(config.text_config.hidden_size, init=config.norm_init)
+        dim_mid = dim_mid // 2 if config.projector_act == "swiglu" else dim_mid
+        dim_out = config.text_config.hidden_size
+        self.linear_2 = nn.Linear(dim_mid, dim_out, bias=False)
+
+        # Ultravox v0.4.1 and below uses layer_norm after the second linear layer,
+        # while v0.5.0 and above uses layer_norm after the first linear layer.
+        if config.projector_ln_mid:
+            self.ln_mid: nn.Module = RMSNorm(dim_mid, init=config.norm_init)
+            self.ln_post: nn.Module = nn.Identity()
+        else:
+            self.ln_mid = nn.Identity()
+            self.ln_post = RMSNorm(dim_out, init=config.norm_init)
 
     def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
         audio_features = self._pad_and_stack(audio_features)
         audio_features = self.ln_pre(audio_features)
         hidden_states = self.linear_1(audio_features)
         hidden_states = self.act(hidden_states)
+        hidden_states = self.ln_mid(hidden_states)
         hidden_states = self.linear_2(hidden_states)
         hidden_states = self.ln_post(hidden_states)
         return hidden_states
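A shape walk-through of the reworked projector under assumed sizes (audio hidden size 1280, `stack_factor=8`, text hidden size 4096, `projector_act="swiglu"`; none of these values come from the diff, and it assumes `StackAudioFrames` pads the time axis up to a multiple of the stack factor):

```python
import torch

# Assumes an UltravoxConfig built elsewhere with the sizes listed above.
projector = UltravoxProjector(config)
x = torch.randn(2, 100, 1280)  # (batch, audio frames, audio hidden)
y = projector(x)

# _pad_and_stack: (2, 100, 1280) -> (2, 13, 10240)   # dim_in = 1280 * 8
# linear_1 -> hidden_dim; swiglu halves it, so dim_mid = hidden_dim // 2
# ln_mid (v0.5.0+) or ln_post (v0.4.1 and below) is the only placement change
# linear_2: (..., dim_mid) -> (..., 4096)            # dim_out
print(y.shape)  # torch.Size([2, 13, 4096])
```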
@@ -544,6 +571,10 @@ class ModifiedWhisperEncoder(
     base_model_prefix = "model.encoder"
     _no_split_modules = ["WhisperEncoderLayer"]
 
+    def __init__(self, config: transformers.WhisperConfig):
+        super().__init__(config)
+        self.config.is_decoder = False
+
     def init_latency_mask(self, audio_latency_block_size: int, dtype: torch.dtype):
         if audio_latency_block_size is None:
             self.audio_streaming_mask = None