Commit

fix style
marcoyang1998 committed Mar 15, 2024
1 parent 7ead73f commit 77bfecd
Showing 6 changed files with 65 additions and 53 deletions.
8 changes: 2 additions & 6 deletions egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -484,13 +484,9 @@ def gigaspeech_subset_small_cuts(self) -> CutSet:
@lru_cache()
def gigaspeech_dev_cuts(self) -> CutSet:
logging.info("About to get Gigaspeech dev cuts")
return load_manifest_lazy(
self.args.manifest_dir / "cuts_DEV.jsonl.gz"
)
return load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz")

@lru_cache()
def gigaspeech_test_cuts(self) -> CutSet:
logging.info("About to get Gigaspeech test cuts")
return load_manifest_lazy(
self.args.manifest_dir / "cuts_TEST.jsonl.gz"
)
return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST.jsonl.gz")
2 changes: 1 addition & 1 deletion egs/librispeech/ASR/zipformer_lora/decode_gigaspeech.py
@@ -121,7 +121,7 @@
modified_beam_search_lm_shallow_fusion,
modified_beam_search_LODR,
)
from finetune import add_model_arguments, add_finetune_arguments, get_model, get_params
from finetune import add_finetune_arguments, add_model_arguments, get_model, get_params

from icefall import ContextGraph, LmScorer, NgramLm
from icefall.checkpoint import (
4 changes: 2 additions & 2 deletions egs/librispeech/ASR/zipformer_lora/export.py
@@ -165,9 +165,9 @@

import k2
import torch
from finetune import add_finetune_arguments, add_model_arguments, get_model, get_params
from scaling_converter import convert_scaled_to_non_scaled
from torch import Tensor, nn
from finetune import add_model_arguments, add_finetune_arguments, get_model, get_params

from icefall.checkpoint import (
average_checkpoints,
@@ -499,7 +499,7 @@ def main():
for k in param_names:
assert k in state_dict.keys()
new_state_dict[k] = state_dict[k]

base_model.load_state_dict(new_state_dict, strict=True)

model = base_model
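
A note on the export hunk above: export.py copies into a plain (non-LoRA) base model only those parameters whose names the base model actually has, so the extra lora_A / lora_B factors are dropped; since ScaledLinear_lora folds its low-rank update into the dense weight when the fine-tuned model is put in eval mode (see scaling.py below), the copied weights can already carry the fine-tuned update. A minimal sketch of that filtering step, using a hypothetical stand-in state dict rather than the real Zipformer checkpoint:

import torch
import torch.nn as nn

# Hypothetical stand-ins for the real base / fine-tuned models handled in export.py.
base_model = nn.Linear(4, 4)
finetuned_state_dict = {
    "weight": torch.randn(4, 4),  # dense weight (LoRA update already merged in eval mode)
    "bias": torch.zeros(4),
    "lora_A": torch.randn(2, 4),  # extra LoRA factors the base model does not have
    "lora_B": torch.randn(4, 2),
}

param_names = [name for name, _ in base_model.named_parameters()]

# Keep only the keys the base model knows about, as in the loop above.
new_state_dict = {}
for k in param_names:
    assert k in finetuned_state_dict, f"missing parameter {k}"
    new_state_dict[k] = finetuned_state_dict[k]

base_model.load_state_dict(new_state_dict, strict=True)
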
18 changes: 8 additions & 10 deletions egs/librispeech/ASR/zipformer_lora/finetune.py
@@ -147,17 +147,11 @@ def add_finetune_arguments(parser: argparse.ArgumentParser):
)

parser.add_argument(
"--use-lora",
type=str2bool,
default=True,
help="If use LoRA for fine-tune"
"--use-lora", type=str2bool, default=True, help="If use LoRA for fine-tune"
)

parser.add_argument(
"--lora-r",
type=int,
default=0,
help="The bottleneck dimension of LoRA"
"--lora-r", type=int, default=0, help="The bottleneck dimension of LoRA"
)

parser.add_argument(
@@ -1287,8 +1281,12 @@ def run(rank, world_size, args):
else:
p.requires_grad = False

logging.info("A total of {} trainable parameters ({:.3f}% of the whole model)".format(num_trainable, num_trainable/num_param * 100))

logging.info(
"A total of {} trainable parameters ({:.3f}% of the whole model)".format(
num_trainable, num_trainable / num_param * 100
)
)

model.to(device)
if world_size > 1:
logging.info("Using DDP")
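
For context on the hunk above: run() freezes the base model so that only the LoRA parameters stay trainable, then logs what fraction of the model they represent. A rough sketch of that pattern; the name-based check on "lora_" is an assumption made for this illustration, not a quote of the real loop:

import logging

import torch.nn as nn


def mark_only_lora_trainable(model: nn.Module) -> None:
    # Illustrative only: freeze everything except parameters whose names
    # contain "lora_" (lora_A / lora_B in ScaledLinear_lora).
    num_param = sum(p.numel() for p in model.parameters())
    num_trainable = 0
    for name, p in model.named_parameters():
        if "lora_" in name:
            p.requires_grad = True
            num_trainable += p.numel()
        else:
            p.requires_grad = False

    logging.info(
        "A total of {} trainable parameters ({:.3f}% of the whole model)".format(
            num_trainable, num_trainable / num_param * 100
        )
    )
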
70 changes: 44 additions & 26 deletions egs/librispeech/ASR/zipformer_lora/scaling.py
@@ -15,16 +15,17 @@
# limitations under the License.


from typing import Optional, Tuple, Union
import logging
import k2
from torch.cuda.amp import custom_fwd, custom_bwd
import math
import random
from typing import Optional, Tuple, Union

import k2
import torch
import math
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd


def logaddexp_onnx(x: Tensor, y: Tensor) -> Tensor:
@@ -518,42 +518,49 @@ def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear:
torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale)
return ans


class LoRALayer:
def __init__(
self,
r: int,
lora_alpha: int,
self,
r: int,
lora_alpha: int,
lora_dropout: float,
merge_weights: bool,
):
self.r = r
self.lora_alpha = lora_alpha
# Optional dropout
if lora_dropout > 0.:
if lora_dropout > 0.0:
self.lora_dropout = nn.Dropout(p=lora_dropout)
else:
self.lora_dropout = lambda x: x
# Mark the weight as unmerged
self.merged = False
self.merge_weights = merge_weights


class ScaledLinear_lora(nn.Linear, LoRALayer):
def __init__(
self,
in_features: int,
out_features: int,
r: int=0,
fan_in_fan_out: bool=False,
lora_alpha: int=1,
lora_dropout: float=0.0,
r: int = 0,
fan_in_fan_out: bool = False,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
initial_scale: float = 1.0,
merge_weights: bool = True,
**kwargs,
):
nn.Linear.__init__(self, in_features, out_features, **kwargs)
LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
merge_weights=merge_weights)

LoRALayer.__init__(
self,
r=r,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
merge_weights=merge_weights,
)

self.initial_scale = initial_scale
self.fan_in_fan_out = fan_in_fan_out
if r > 0:
@@ -563,7 +571,7 @@ def __init__(
self.weight.requires_grad = False

self.reset_parameters()

def reset_parameters(self):
# initialize the parameters
nn.Linear.reset_parameters(self)
@@ -572,16 +580,19 @@ def reset_parameters(self):
with torch.no_grad():
self.weight[:] *= initial_scale
if self.bias is not None:
nn.init.uniform_(self.bias, -0.1 * initial_scale, 0.1 * initial_scale)
if hasattr(self, 'lora_A'):
nn.init.uniform_(
self.bias, -0.1 * initial_scale, 0.1 * initial_scale
)
if hasattr(self, "lora_A"):
# initialize B the same way as the default for nn.Linear and A to zero
# this is different than what is described in the paper but should not affect performance
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
nn.init.zeros_(self.lora_B)
def train(self, mode: bool=True):

def train(self, mode: bool = True):
def T(w):
return w.transpose(0, 1) if self.fan_in_fan_out else w

nn.Linear.train(self, mode)
if mode:
# We don't want the weights to be merged in training mode
@@ -595,18 +606,24 @@ def T(w):
# Merge the weights and mark it
if self.r > 0:
self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
self.merged = True
self.merged = True

def forward(self, x: torch.Tensor):
def T(w):
return w.transpose(0, 1) if self.fan_in_fan_out else w

if self.r > 0 and not self.merged:
result = F.linear(x, T(self.weight), bias=self.bias)
delta_result = self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)
delta_result = (
self.lora_dropout(x)
@ self.lora_A.transpose(0, 1)
@ self.lora_B.transpose(0, 1)
)
return result + delta_result * self.scaling
else:
return F.linear(x, T(self.weight), bias=self.bias)


def ScaledConv1d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv1d:
"""
Behaves like a constructor of a modified version of nn.Conv1d
@@ -1740,6 +1757,7 @@ def forward(self, x: Tensor):
self.dropout_shared_dim,
)


class ActivationDropoutAndLinear_lora(torch.nn.Module):
def __init__(
self,
@@ -1749,9 +1767,9 @@ def __init__(
activation: str = "SwooshL",
dropout_p: FloatLike = 0.0,
dropout_shared_dim: Optional[int] = -1,
r: int=0,
lora_alpha: int=1,
lora_dropout: float=0.0,
r: int = 0,
lora_alpha: int = 1,
lora_dropout: float = 0.0,
initial_scale: float = 1.0,
):
super().__init__()
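
The substance behind the ScaledLinear_lora reformatting above is the equivalence between the unmerged forward pass (dense path plus low-rank path) and the merged weight produced by train(False). A quick numerical check of that identity, assuming lora_dropout = 0 and writing the scaling factor as lora_alpha / r, which is the usual LoRA convention; the exact expression lives in a part of __init__ collapsed out of this diff:

import torch

torch.manual_seed(0)
in_features, out_features, r, lora_alpha = 8, 6, 4, 4
scaling = lora_alpha / r  # assumed form; defined in the collapsed part of __init__

x = torch.randn(3, in_features)
weight = torch.randn(out_features, in_features)
lora_A = torch.randn(r, in_features)
lora_B = torch.randn(out_features, r)

# Unmerged path, as in ScaledLinear_lora.forward() with no dropout:
y_unmerged = x @ weight.t() + (x @ lora_A.t() @ lora_B.t()) * scaling

# Merged path, as after train(False) folds the update into the weight:
y_merged = x @ (weight + (lora_B @ lora_A) * scaling).t()

assert torch.allclose(y_unmerged, y_merged, atol=1e-5)
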
16 changes: 8 additions & 8 deletions egs/librispeech/ASR/zipformer_lora/zipformer.py
@@ -30,7 +30,6 @@
)
from scaling import (
ScaledLinear, # not as in other dirs.. just scales down initial parameter values.
ScaledLinear_lora
)
from scaling import (
ActivationDropoutAndLinear,
@@ -40,6 +39,7 @@
ChunkCausalDepthwiseConv1d,
Dropout2,
FloatLike,
ScaledLinear_lora,
ScheduledFloat,
Whiten,
convert_num_channels,
@@ -636,7 +636,7 @@ def __init__(
)

self.self_attn1 = SelfAttention(
embed_dim,
embed_dim,
num_heads,
value_head_dim,
lora_r=lora_r,
@@ -645,7 +645,7 @@ def __init__(
)

self.self_attn2 = SelfAttention(
embed_dim,
embed_dim,
num_heads,
value_head_dim,
lora_r=lora_r,
@@ -654,7 +654,7 @@ def __init__(
)

self.feed_forward1 = FeedforwardModule(
embed_dim,
embed_dim,
(feedforward_dim * 3) // 4,
dropout,
lora_r=lora_r,
@@ -672,7 +672,7 @@ def __init__(
)

self.feed_forward3 = FeedforwardModule(
embed_dim,
embed_dim,
(feedforward_dim * 5) // 4,
dropout,
lora_r=lora_r,
@@ -1566,7 +1566,7 @@ def __init__(
pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.0)),
lora_r: int = 0,
lora_alpha: int = 4,
lora_dropout: float=0.0
lora_dropout: float = 0.0,
) -> None:
super().__init__()
self.embed_dim = embed_dim
@@ -1935,7 +1935,7 @@ def __init__(
value_head_dim: int,
lora_r: int = 0,
lora_alpha: int = 4,
lora_dropout: float=0.0
lora_dropout: float = 0.0,
) -> None:
super().__init__()
self.in_proj = ScaledLinear_lora(
@@ -2064,7 +2064,7 @@ def __init__(
dropout: FloatLike,
lora_r: int = 0,
lora_alpha: int = 4,
lora_dropout: float=0.0
lora_dropout: float = 0.0,
):
super(FeedforwardModule, self).__init__()
self.in_proj = ScaledLinear_lora(
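
The zipformer.py hunks above only thread lora_r, lora_alpha and lora_dropout from the encoder layers down to the projections that were switched to ScaledLinear_lora. A small sketch of that wiring, with a toy module standing in for FeedforwardModule (the dimensions and the r value are illustrative, not values from the recipe):

import torch
import torch.nn as nn
from scaling import ScaledLinear_lora  # same import as in zipformer.py


class ToyFeedforward(nn.Module):
    # Illustrative only; not the real FeedforwardModule.
    def __init__(
        self,
        embed_dim: int,
        feedforward_dim: int,
        lora_r: int = 0,
        lora_alpha: int = 4,
        lora_dropout: float = 0.0,
    ):
        super().__init__()
        self.in_proj = ScaledLinear_lora(
            embed_dim,
            feedforward_dim,
            r=lora_r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.in_proj(x)


layer = ToyFeedforward(embed_dim=192, feedforward_dim=768, lora_r=8)
out = layer(torch.randn(10, 192))  # dense weight of in_proj is frozen; lora_A / lora_B carry the update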
