Update model.py
model.py CHANGED
@@ -193,7 +193,8 @@ class FlaxFalconAttention(nn.Module):
     def setup(self) -> None:
         head_dim = self.config.hidden_size // self.config.n_head
         self.w_qkv = nn.Dense(
-            features=self.config.hidden_size
+            features=3 * self.config.hidden_size if not self.config.multi_query else (
+                self.config.hidden_size + 2 * head_dim),
             dtype=self.dtype,
             param_dtype=self.param_dtype,
             use_bias=self.config.bias
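The new features expression sizes a single fused projection for both attention variants: standard multi-head attention needs full-width Q, K and V, while multi-query attention keeps n_head query heads but only one shared key head and one shared value head. It also supplies the trailing comma the removed line was missing before dtype=self.dtype. A minimal standalone sketch of the width arithmetic, using Falcon-7B-like values (hidden_size=4544 and n_head=71 are illustrative assumptions, not taken from the diff):

    # Fused-QKV projection width for both variants (toy sketch, values hypothetical).
    hidden_size = 4544
    n_head = 71
    head_dim = hidden_size // n_head            # 64

    mha_features = 3 * hidden_size              # 13632: n_head heads each for Q, K and V
    mqa_features = hidden_size + 2 * head_dim   # 4672: full Q, one shared K and V head

    print(head_dim, mha_features, mqa_features)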
@@ -206,6 +207,7 @@ class FlaxFalconAttention(nn.Module):
             use_bias=self.config.bias
         )
         self.head_dim = head_dim
+        assert self.head_dim * self.config.n_head == self.config.hidden_size
         if not self.config.alibi:
             self.freq = precompute_freqs_cis(head_dim, self.config.max_seq_len, dtype=self.dtype)
 
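The added assert makes the integer division in setup fail loudly instead of silently truncating. A toy check (values hypothetical):

    hidden_size, n_head = 4544, 71
    head_dim = hidden_size // n_head        # 64, and 64 * 71 == 4544, so the check passes
    assert head_dim * n_head == hidden_size
    # With n_head = 70 instead, head_dim would truncate to 64 and 64 * 70 == 4480 != 4544,
    # so the assert would fire rather than letting the attention width silently shrink.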
@@ -215,13 +217,25 @@ class FlaxFalconAttention(nn.Module):
         attention_mask: jnp.DeviceArray = None,
     ):
         b, s, d = hidden_states.shape
-
-
-
-
-
-
-
+        qkv = self.w_qkv(hidden_states)
+        if not self.config.multi_query:
+            q, k, v = jnp.split(qkv, 3, -1)
+            q = with_sharding_constraint(q, PartitionSpec(('dp', 'fsdp'), None, 'mp'))
+            k = with_sharding_constraint(k, PartitionSpec(('dp', 'fsdp'), None, 'mp'))
+            v = with_sharding_constraint(v, PartitionSpec(('dp', 'fsdp'), None, 'mp'))
+            k = rearrange(k, 'b s (h d) -> b s h d', h=self.config.n_head)
+            q = rearrange(q, 'b s (h d) -> b s h d', h=self.config.n_head)
+            v = rearrange(v, 'b s (h d) -> b s h d', h=self.config.n_head)
+        else:
+            qkv = qkv.reshape(
+                b, s, self.config.n_head + 2, -1
+            )
+            q, k, v = qkv[..., :-2, :], qkv[..., [-2], :], qkv[..., [-1], :]
+
+            q = with_sharding_constraint(q, PartitionSpec(('dp', 'fsdp'), None, None, 'mp'))
+            k = with_sharding_constraint(k, PartitionSpec(('dp', 'fsdp'), None, None, 'mp'))
+            v = with_sharding_constraint(v, PartitionSpec(('dp', 'fsdp'), None, None, 'mp'))
+
         if not self.config.alibi:
             freq = self.freq[:s].reshape(1, s, -1)
             q, k = apply_rotary_emb(q, k, freq, self.dtype)
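In the multi-query branch, the fused projection is reshaped so the second-to-last axis holds n_head + 2 heads; indexing with [-2] and [-1] (rather than -2 and -1) keeps that head axis, so the single key and value heads can broadcast against the n_head query heads downstream. A shape-only sketch with toy dimensions (all values hypothetical):

    import jax.numpy as jnp

    b, s, n_head, head_dim = 2, 5, 4, 8
    qkv = jnp.zeros((b, s, n_head * head_dim + 2 * head_dim))  # fused MQA projection output

    qkv = qkv.reshape(b, s, n_head + 2, -1)                    # (2, 5, 6, 8)
    q = qkv[..., :-2, :]                                       # (2, 5, 4, 8): n_head query heads
    k = qkv[..., [-2], :]                                      # (2, 5, 1, 8): one shared key head
    v = qkv[..., [-1], :]                                      # (2, 5, 1, 8): one shared value head
    print(q.shape, k.shape, v.shape)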
@@ -231,8 +245,10 @@ class FlaxFalconAttention(nn.Module):
         if alibi is not None:
             attn += attn
         attn = attn * self.factor_scale
+
         if attention_mask is not None:
             attn += attention_mask
+
         attn = jax.nn.softmax(attn, axis=-1)
         attn = jnp.einsum('...hqk,...khd->...qhd', attn, v, precision=self.precision).reshape((b, s, d))
         return self.wo(attn)
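The lines added here are blank lines that only break up the score pipeline visually. For orientation, a standalone sketch of that pipeline with hypothetical shapes and a plain 1/sqrt(d) scale standing in for self.factor_scale; note that the context line attn += attn under the alibi branch doubles the scores and reads as if attn += alibi were intended, but this commit leaves it unchanged:

    import jax
    import jax.numpy as jnp

    b, s, h, d = 2, 5, 4, 8
    q = jnp.zeros((b, s, h, d))
    k = jnp.zeros((b, s, h, d))
    v = jnp.zeros((b, s, h, d))
    alibi = jnp.zeros((b, h, s, s))              # per-head linear position bias
    mask = jnp.zeros((b, 1, s, s))               # additive mask: 0 keeps, large negative drops

    attn = jnp.einsum('...qhd,...khd->...hqk', q, k)      # (b, h, s, s) raw scores
    attn = attn + alibi                                   # where the diff's context line has attn += attn
    attn = attn * (1.0 / jnp.sqrt(d))                     # stand-in for self.factor_scale
    attn = attn + mask
    attn = jax.nn.softmax(attn, axis=-1)
    out = jnp.einsum('...hqk,...khd->...qhd', attn, v).reshape(b, s, h * d)
    print(out.shape)                                      # (2, 5, 32)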
|