Update modeling_internlm.py
Browse files- modeling_internlm.py +14 -37
modeling_internlm.py
CHANGED
@@ -86,17 +86,6 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
|
|
86 |
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
87 |
|
88 |
|
89 |
-
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
90 |
-
"""
|
91 |
-
(batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
92 |
-
"""
|
93 |
-
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
94 |
-
if n_rep == 1:
|
95 |
-
return hidden_states
|
96 |
-
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
97 |
-
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
98 |
-
|
99 |
-
|
100 |
class InternLMRMSNorm(nn.Module):
|
101 |
"""RMSNorm implemention."""
|
102 |
|
@@ -272,8 +261,6 @@ class InternLMAttention(nn.Module):
|
|
272 |
self.hidden_size = config.hidden_size
|
273 |
self.num_heads = config.num_attention_heads
|
274 |
self.head_dim = self.hidden_size // self.num_heads
|
275 |
-
self.num_key_value_heads = config.num_key_value_heads
|
276 |
-
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
277 |
self.max_position_embeddings = config.max_position_embeddings
|
278 |
|
279 |
if (self.head_dim * self.num_heads) != self.hidden_size:
|
@@ -282,30 +269,27 @@ class InternLMAttention(nn.Module):
|
|
282 |
f" and `num_heads`: {self.num_heads})."
|
283 |
)
|
284 |
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.bias)
|
285 |
-
self.k_proj = nn.Linear(self.hidden_size, self.
|
286 |
-
self.v_proj = nn.Linear(self.hidden_size, self.
|
287 |
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.bias)
|
288 |
self.rotary_emb = self._init_rope()
|
289 |
|
290 |
def _init_rope(self):
|
291 |
-
if self.config.
|
292 |
self.rotary_emb = InternLMRotaryEmbedding(
|
293 |
self.head_dim,
|
294 |
max_position_embeddings=self.max_position_embeddings,
|
295 |
-
base=self.config.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
296 |
)
|
297 |
else:
|
298 |
-
|
299 |
-
scaling_factor = self.config.rope_scaling["factor"]
|
300 |
-
if scaling_type == "dynamic":
|
301 |
-
self.rotary_emb = InternLMDynamicNTKScalingRotaryEmbedding(
|
302 |
-
self.head_dim,
|
303 |
-
max_position_embeddings=self.max_position_embeddings,
|
304 |
-
base=self.config.rope_theta,
|
305 |
-
scaling_factor=scaling_factor,
|
306 |
-
)
|
307 |
-
else:
|
308 |
-
raise ValueError("Currently we only support rotary embedding's type being 'dynamic'.")
|
309 |
return self.rotary_emb
|
310 |
|
311 |
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
@@ -323,12 +307,8 @@ class InternLMAttention(nn.Module):
|
|
323 |
bsz, q_len, _ = hidden_states.size()
|
324 |
|
325 |
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
326 |
-
key_states = (
|
327 |
-
|
328 |
-
)
|
329 |
-
value_states = (
|
330 |
-
self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
331 |
-
)
|
332 |
|
333 |
if past_key_value is not None:
|
334 |
# reuse k, v, self_attention
|
@@ -341,9 +321,6 @@ class InternLMAttention(nn.Module):
|
|
341 |
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
342 |
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
343 |
|
344 |
-
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
345 |
-
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
346 |
-
|
347 |
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
348 |
|
349 |
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
|
|
86 |
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
87 |
|
88 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
class InternLMRMSNorm(nn.Module):
|
90 |
"""RMSNorm implemention."""
|
91 |
|
|
|
261 |
self.hidden_size = config.hidden_size
|
262 |
self.num_heads = config.num_attention_heads
|
263 |
self.head_dim = self.hidden_size // self.num_heads
|
|
|
|
|
264 |
self.max_position_embeddings = config.max_position_embeddings
|
265 |
|
266 |
if (self.head_dim * self.num_heads) != self.hidden_size:
|
|
|
269 |
f" and `num_heads`: {self.num_heads})."
|
270 |
)
|
271 |
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.bias)
|
272 |
+
self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.bias)
|
273 |
+
self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.bias)
|
274 |
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.bias)
|
275 |
self.rotary_emb = self._init_rope()
|
276 |
|
277 |
def _init_rope(self):
|
278 |
+
if self.config.rotary["type"] == "origin"
|
279 |
self.rotary_emb = InternLMRotaryEmbedding(
|
280 |
self.head_dim,
|
281 |
max_position_embeddings=self.max_position_embeddings,
|
282 |
+
base=self.config.rotary["base"],
|
283 |
+
)
|
284 |
+
elif self.config.rotary["type"] == "dynamic":
|
285 |
+
self.rotary_emb = InternLMDynamicNTKScalingRotaryEmbedding(
|
286 |
+
self.head_dim,
|
287 |
+
max_position_embeddings=self.max_position_embeddings,
|
288 |
+
base=self.config.rotary["base"],
|
289 |
+
scaling_factor=self.config.rotary.get("scaling_factor", 1.0),
|
290 |
)
|
291 |
else:
|
292 |
+
raise ValueError("Currently we only support rotary embedding's type being one of ('origin', 'dynamic').")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
293 |
return self.rotary_emb
|
294 |
|
295 |
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
|
|
307 |
bsz, q_len, _ = hidden_states.size()
|
308 |
|
309 |
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
310 |
+
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
311 |
+
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
|
|
|
|
|
|
|
|
312 |
|
313 |
if past_key_value is not None:
|
314 |
# reuse k, v, self_attention
|
|
|
321 |
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
322 |
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
323 |
|
|
|
|
|
|
|
324 |
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
325 |
|
326 |
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|