x54-729 committed on
Commit
41d18d0
1 Parent(s): d5770d1

Update modeling_internlm.py

Files changed (1)
  1. modeling_internlm.py +14 -37
modeling_internlm.py CHANGED
@@ -86,17 +86,6 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
     return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)


-def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
-    """
-    (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
-    """
-    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
-    if n_rep == 1:
-        return hidden_states
-    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
-    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-
-
 class InternLMRMSNorm(nn.Module):
     """RMSNorm implemention."""

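The deleted repeat_kv helper is the grouped-query-attention (GQA) expansion step: each of the num_key_value_heads key/value heads is tiled n_rep times so K and V line up with the full set of query heads. A minimal, self-contained sketch of what the removed helper computes, with the head counts below chosen purely for illustration:

import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, num_key_value_heads, seqlen, head_dim) -> (batch, num_key_value_heads * n_rep, seqlen, head_dim)
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

# Illustrative shapes only: 8 query heads sharing 2 key/value heads.
kv = torch.randn(1, 2, 16, 64)             # (batch=1, num_key_value_heads=2, seqlen=16, head_dim=64)
expanded = repeat_kv(kv, n_rep=8 // 2)     # each K/V head repeated 4 times
assert expanded.shape == (1, 8, 16, 64)    # now one K/V head per query head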
@@ -272,8 +261,6 @@ class InternLMAttention(nn.Module):
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.hidden_size // self.num_heads
-        self.num_key_value_heads = config.num_key_value_heads
-        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.max_position_embeddings = config.max_position_embeddings

         if (self.head_dim * self.num_heads) != self.hidden_size:
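The two attributes removed above are the GQA bookkeeping that fed repeat_kv: num_key_value_groups is simply how many query heads share one key/value head. A worked example with hypothetical head counts, not taken from any released InternLM config:

# Hypothetical head counts, for illustration only.
num_attention_heads = 32
num_key_value_heads = 8
num_key_value_groups = num_attention_heads // num_key_value_heads
assert num_key_value_groups == 4   # 4 query heads attend through each shared K/V head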
@@ -282,30 +269,27 @@ class InternLMAttention(nn.Module):
             f" and `num_heads`: {self.num_heads})."
         )
         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.bias)
-        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.bias)
-        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.bias)
+        self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.bias)
+        self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.bias)
         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.bias)
         self.rotary_emb = self._init_rope()

     def _init_rope(self):
-        if self.config.rope_scaling is None:
+        if self.config.rotary["type"] == "origin":
             self.rotary_emb = InternLMRotaryEmbedding(
                 self.head_dim,
                 max_position_embeddings=self.max_position_embeddings,
-                base=self.config.rope_theta,
+                base=self.config.rotary["base"],
+            )
+        elif self.config.rotary["type"] == "dynamic":
+            self.rotary_emb = InternLMDynamicNTKScalingRotaryEmbedding(
+                self.head_dim,
+                max_position_embeddings=self.max_position_embeddings,
+                base=self.config.rotary["base"],
+                scaling_factor=self.config.rotary.get("scaling_factor", 1.0),
             )
         else:
-            scaling_type = self.config.rope_scaling["type"]
-            scaling_factor = self.config.rope_scaling["factor"]
-            if scaling_type == "dynamic":
-                self.rotary_emb = InternLMDynamicNTKScalingRotaryEmbedding(
-                    self.head_dim,
-                    max_position_embeddings=self.max_position_embeddings,
-                    base=self.config.rope_theta,
-                    scaling_factor=scaling_factor,
-                )
-            else:
-                raise ValueError("Currently we only support rotary embedding's type being 'dynamic'.")
+            raise ValueError("Currently we only support rotary embedding's type being one of ('origin', 'dynamic').")
         return self.rotary_emb

     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
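With this commit the rope settings are read from a single config.rotary dict ("type", "base", and an optional "scaling_factor") instead of the rope_scaling/rope_theta pair. A sketch of the two accepted shapes; the numeric values are placeholders, not values from any particular checkpoint:

# Plain rotary embedding, no NTK scaling.
rotary_origin = {"type": "origin", "base": 10000}

# Dynamic NTK scaling; "scaling_factor" is optional and defaults to 1.0 in _init_rope.
rotary_dynamic = {"type": "dynamic", "base": 10000, "scaling_factor": 1.0}

def check_rotary(rotary: dict) -> None:
    # Mirrors the dispatch in the new _init_rope: any other type raises.
    if rotary["type"] not in ("origin", "dynamic"):
        raise ValueError("Currently we only support rotary embedding's type being one of ('origin', 'dynamic').")

check_rotary(rotary_origin)
check_rotary(rotary_dynamic)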
@@ -323,12 +307,8 @@ class InternLMAttention(nn.Module):
         bsz, q_len, _ = hidden_states.size()

         query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = (
-            self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        )
-        value_states = (
-            self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        )
+        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

         if past_key_value is not None:
             # reuse k, v, self_attention
@@ -341,9 +321,6 @@ class InternLMAttention(nn.Module):
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
-
         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

         if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
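With the GQA path gone, Q, K and V all carry num_heads heads, so no expansion is needed before the score computation and attn_weights comes out as (bsz, num_heads, q_len, kv_seq_len). A shape-only sketch of that path with small illustrative dimensions and random stand-ins for the projections:

import math
import torch

bsz, q_len, num_heads, head_dim = 1, 16, 8, 64      # illustrative sizes only
hidden_states = torch.randn(bsz, q_len, num_heads * head_dim)

# Stand-ins for q_proj/k_proj/v_proj: all three now map to num_heads * head_dim.
query_states = hidden_states.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
key_states = hidden_states.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
value_states = hidden_states.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)

attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim)
assert attn_weights.shape == (bsz, num_heads, q_len, q_len)   # (bsz, num_heads, q_len, kv_seq_len)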