-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodules.py
More file actions
156 lines (135 loc) · 5.87 KB
/
modules.py
File metadata and controls
156 lines (135 loc) · 5.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import mlx.core as mx
import mlx.nn as nn
import numpy as np
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization with a learnable per-channel gain."""

    def __init__(self, dim: int, eps: float = 1e-8):
        super().__init__()
        self.eps = eps
        # Per-channel gain, initialized to identity.
        self.gamma = mx.ones((dim,))

    def __call__(self, x):
        # Scale by the reciprocal RMS of the last axis, then apply the gain.
        mean_sq = mx.mean(mx.square(x), axis=-1, keepdims=True)
        return x * mx.rsqrt(mean_sq + self.eps) * self.gamma
class RotaryEmbedding(nn.Module):
    """Produces the per-position rotation angles used for rotary embeddings."""

    def __init__(self, dim, base=10000):
        super().__init__()
        # One inverse frequency per (even, odd) channel pair.
        exponents = mx.arange(0, dim, 2).astype(mx.float32) / dim
        self.inv_freq = 1.0 / (base ** exponents)

    def __call__(self, x, seq_len):
        # Returns a (seq_len, dim // 2) table of position * frequency angles.
        # `x` is accepted for interface compatibility but not read here.
        positions = mx.arange(seq_len, dtype=mx.float32)
        return mx.outer(positions, self.inv_freq)
def apply_rotary_pos_emb(q, k, freqs):
    """Rotate query/key channel pairs by the angles in `freqs`.

    `freqs` holds one angle per channel pair; each angle is duplicated
    (via mx.repeat) so it covers both channels of its interleaved pair.
    """
    angles = mx.repeat(freqs, 2, axis=-1)
    cos_a = mx.cos(angles)
    sin_a = mx.sin(angles)

    def _rotate_half(t):
        # Map each interleaved pair (a, b) to (-b, a).
        even, odd = t[..., 0::2], t[..., 1::2]
        return mx.stack([-odd, even], axis=-1).reshape(t.shape)

    q_rot = q * cos_a + _rotate_half(q) * sin_a
    k_rot = k * cos_a + _rotate_half(k) * sin_a
    return q_rot, k_rot
def l2norm(t):
    """L2-normalize `t` along its last axis (epsilon guards against zero norm)."""
    norm = mx.sqrt(mx.sum(mx.square(t), axis=-1, keepdims=True) + 1e-6)
    return t / norm
class Attention(nn.Module):
    """Pre-norm multi-head self-attention with per-head output gating.

    Flow: RMSNorm -> fused QKV projection -> optional rotary embedding ->
    fused scaled-dot-product attention -> sigmoid per-head gates -> output
    projection with dropout.
    """
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.0, rotary_embed=None):
        super().__init__()
        # Standard 1/sqrt(dim_head) attention scale.
        self.heads, self.scale = heads, dim_head ** -0.5
        inner_dim = heads * dim_head
        self.norm = RMSNorm(dim)
        # Single projection producing Q, K, V concatenated along the last axis.
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        # Learnable temperature for Q scaling
        # NOTE(review): `temperature` is defined but never applied in
        # __call__ — confirm against the reference implementation whether
        # it should multiply Q (dead parameter otherwise).
        self.temperature = mx.ones((heads, 1, 1))
        # Per-head sigmoid gates computed from the normed input.
        self.to_gates = nn.Linear(dim, heads, bias=True)
        self.to_out = nn.Sequential(nn.Linear(inner_dim, dim, bias=False), nn.Dropout(dropout))
        self.rotary_embed = rotary_embed
    def __call__(self, x):
        x = self.norm(x)
        B, L, _ = x.shape
        H = self.heads
        qkv = self.to_qkv(x)
        q, k, v = mx.split(qkv, 3, axis=-1)
        # Reshape to (B, H, L, D)
        q = q.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
        k = k.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
        v = v.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
        if self.rotary_embed:
            # freqs: (L, dim_head // 2); broadcast over batch and heads.
            freqs = self.rotary_embed(x, L)
            q, k = apply_rotary_pos_emb(q, k, freqs)
        # Fused SDPA
        # Supports (B, H, L, D)
        out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale)
        # Gating
        # gates: (B, L, H) -> (B, H, L, 1), multiplies each head's output.
        gates = mx.sigmoid(self.to_gates(x)).transpose(0, 2, 1)[..., None]
        out = out * gates
        # Output projection
        # Merge heads back to (B, L, H * D) before projecting to `dim`.
        return self.to_out(out.transpose(0, 2, 1, 3).reshape(B, L, -1))
class FeedForward(nn.Module):
    """Pre-norm two-layer MLP: RMSNorm -> Linear -> GELU -> Linear, with dropout."""

    def __init__(self, dim, mult=4, dropout=0.0):
        super().__init__()
        hidden = int(dim * mult)
        self.net = nn.Sequential(
            RMSNorm(dim),
            nn.Linear(dim, hidden, bias=True),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, dim, bias=True),
            nn.Dropout(dropout),
        )

    def __call__(self, x):
        return self.net(x)
class Transformer(nn.Module):
    """Stack of attention + feed-forward pairs, each wrapped in a residual."""

    def __init__(self, dim, depth, heads, dim_head, mult, dropout, rotary_embed=None):
        super().__init__()
        # Each depth level is a (attention, feed-forward) pair; kept inside
        # nn.Sequential containers so parameter paths stay stable.
        pairs = [
            nn.Sequential(
                Attention(dim, heads, dim_head, dropout, rotary_embed),
                FeedForward(dim, mult, dropout),
            )
            for _ in range(depth)
        ]
        self.layers = nn.Sequential(*pairs)

    def __call__(self, x):
        # Apply each sub-module with its own residual connection.
        for pair in self.layers.layers:
            attn, ff = pair.layers
            x = x + attn(x)
            x = x + ff(x)
        return x
class MacaronFF(nn.Module):
    """Half-weighted feed-forward, used at both ends of a Conformer block."""

    def __init__(self, dim, mult=4, dropout=0.0):
        super().__init__()
        self.ff = FeedForward(dim, mult, dropout)

    def __call__(self, x):
        # Macaron style: each of the two surrounding FF modules contributes
        # half of a full feed-forward update.
        return 0.5 * self.ff(x)
def MLP(dim_in, dim_out, dim_hidden, depth):
    """Build the layer list for a Tanh MLP; the caller assembles the module.

    Hidden layers are `dim_hidden` wide and followed by Tanh; the final
    Linear maps to `dim_out` with no activation.
    """
    layers = []
    in_d = dim_in
    for idx in range(depth):
        is_last = idx == depth - 1
        out_d = dim_out if is_last else dim_hidden
        layers.append(nn.Linear(in_d, out_d))
        if not is_last:
            layers.append(nn.Tanh())
        in_d = out_d
    return layers
class ConformerConvModule(nn.Module):
    """Conformer convolution module: pointwise-GLU -> depthwise conv ->
    BatchNorm -> SiLU -> pointwise projection.

    The indexed comments mirror the layer positions of the PyTorch original
    so checkpoint weight keys line up; the Identity slots stand in for
    Rearrange layers that MLX does not need (Conv1d here is channels-last,
    so the GLU split on axis=-1 acts over channels).
    """
    def __init__(self, dim, expansion_factor=2, kernel_size=31, dropout=0.0):
        super().__init__()
        inner = dim * expansion_factor
        self.net = nn.Sequential(
            RMSNorm(dim), # 0
            nn.Identity(), # 1 (Rearrange in PT)
            nn.Conv1d(dim, inner * 2, 1), # 2 — pointwise, doubled for GLU
            nn.GLU(axis=-1), # 3 — halves channels back to `inner`
            nn.Conv1d(inner, inner, kernel_size, padding=(kernel_size - 1) // 2, groups=inner), # 4 — depthwise, "same" length
            nn.BatchNorm(num_features=inner), # 5
            nn.SiLU(), # 6
            nn.Conv1d(inner, dim, 1), # 7 — pointwise back to `dim`
            nn.Identity(), # 8 (Rearrange in PT)
            nn.Dropout(dropout) # 9
        )
    def __call__(self, x):
        # Expects channels-last input (batch, length, dim) — TODO confirm
        # against callers; no residual here, callers add it.
        return self.net(x)
class ConformerBlock(nn.Module):
    """One Conformer block: half-FF -> attention -> conv -> half-FF -> norm,
    with a residual connection around every stage."""

    def __init__(self, dim, heads, dim_head, ff_mult, attn_dropout, ff_dropout, conv_expansion_factor, conv_kernel_size, rotary_embed):
        super().__init__()
        self.ff1 = MacaronFF(dim, ff_mult, ff_dropout)
        self.attn = Attention(dim, heads, dim_head, attn_dropout, rotary_embed)
        self.conv = ConformerConvModule(dim, conv_expansion_factor, conv_kernel_size, ff_dropout)
        self.ff2 = MacaronFF(dim, ff_mult, ff_dropout)
        self.out_norm = RMSNorm(dim)

    def __call__(self, x):
        # Residual around each sub-module, then a closing normalization.
        for stage in (self.ff1, self.attn, self.conv, self.ff2):
            x = x + stage(x)
        return self.out_norm(x)
class Conformer(nn.Module):
    """A depth-`depth` stack of ConformerBlocks sharing one rotary embedding."""

    def __init__(self, dim, depth, heads, dim_head, ff_mult, attn_dropout, ff_dropout, conv_expansion_factor, conv_kernel_size, rotary_embed):
        super().__init__()
        blocks = [
            ConformerBlock(
                dim, heads, dim_head, ff_mult, attn_dropout,
                ff_dropout, conv_expansion_factor, conv_kernel_size,
                rotary_embed,
            )
            for _ in range(depth)
        ]
        self.layers = nn.Sequential(*blocks)

    def __call__(self, x):
        # Run the blocks in order; each block handles its own residuals.
        for block in self.layers.layers:
            x = block(x)
        return x