BeardedMonster committed
Commit 93fd473
1 Parent(s): 4f4f5b6

Upload GPTJXForCausalLM

Files changed (1)
  1. pretrained_model.py +144 -47
pretrained_model.py CHANGED
@@ -1,7 +1,7 @@
 from transformers import AutoConfig, PreTrainedModel, AutoModelForCausalLM
 from typing import List, Optional
 from torch import nn
-from model import LayerNorm, BlockJ
+# from model import LayerNorm, BlockJ
 from transformers.modeling_outputs import CausalLMOutputWithPast
 import torch
 import math
@@ -10,6 +10,103 @@ from transformers import AutoConfig, AutoModel
 from .pretrained_config import *


+
+class LayerNorm(nn.Module):
+    """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """
+
+    def __init__(self, ndim, bias):
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(ndim))
+        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
+
+    def forward(self, input):
+        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
+
+class CausalSelfAttention(nn.Module):
+
+    def __init__(self, config):
+        super().__init__()
+        assert config.n_embd % config.n_head == 0
+        # key, query, value projections for all heads, but in a batch
+        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
+        # output projection
+        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
+        # regularization
+        self.attn_dropout = nn.Dropout(config.dropout)
+        self.resid_dropout = nn.Dropout(config.dropout)
+        self.n_head = config.n_head
+        self.n_embd = config.n_embd
+        self.dropout = config.dropout
+        # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
+        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
+        # if not self.flash:
+        #     print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
+        # causal mask to ensure that attention is only applied to the left in the input sequence
+        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
+                                     .view(1, 1, config.block_size, config.block_size))
+
+    def forward(self, x, attn_mask=None):
+        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
+
+        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
+        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
+
+        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
+        if self.flash:
+            if attn_mask is not None:
+                # efficient attention using Flash Attention CUDA kernels
+                y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=self.dropout if self.training else 0)
+            else:
+                y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
+        else:
+            # manual implementation of attention
+            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+            att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
+            att = F.softmax(att, dim=-1)
+            att = self.attn_dropout(att)
+            y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
+        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
+
+        # output projection
+        y = self.resid_dropout(self.c_proj(y))
+        return y
+
+class MLP(nn.Module):
+
+    def __init__(self, config):
+        super().__init__()
+        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
+        self.gelu = nn.GELU()
+        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
+        self.dropout = nn.Dropout(config.dropout)
+
+    def forward(self, x):
+        x = self.c_fc(x)
+        x = self.gelu(x)
+        x = self.c_proj(x)
+        x = self.dropout(x)
+        return x
+
+class BlockJ(nn.Module):
+
+    def __init__(self, config):
+        super().__init__()
+        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
+        self.j = LayerNorm(config.n_embd, config.n_embd)
+        self.attn = CausalSelfAttention(config)
+        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
+        self.mlp = MLP(config)
+
+    def forward(self, x, attn_mask=None):
+        h = x
+        x = self.ln_1(x)
+        x = h + self.attn(x, attn_mask) + self.j(x)
+        x = x + self.mlp(self.ln_2(x))
+        return x
+

 class GPTJXForCausalLM(PreTrainedModel):
     config_class = GPTJXConfig
@@ -117,36 +214,36 @@ class GPTJXForCausalLM(PreTrainedModel):
         return model_inputs


-    @torch.no_grad()
-    def stream(self, idx, max_new_tokens, temperature=1.0, top_k=None, gen_mode="greedy"):
-        """
-        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
-        the sequence max_new_tokens times, feeding the predictions back into the model each time.
-        Most likely you'll want to make sure to be in model.eval() mode of operation for this.
-        """
-        for _ in range(max_new_tokens):
-            # if the sequence context is growing too long we must crop it at block_size
-            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
-            # forward the model to get the logits for the index in the sequence
-            logits, _ = self(idx_cond, eval=True)
-            # pluck the logits at the final step and scale by desired temperature
-            logits = logits[:, -1, :] / temperature
-            # optionally crop the logits to only the top k options
-            if top_k is not None:
-                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
-                logits[logits < v[:, [-1]]] = -float('Inf')
-            # apply softmax to convert logits to (normalized) probabilities
-            probs = F.softmax(logits, dim=-1)
-            # sample from the distribution
-            if gen_mode == 'greedy':
-                idx_next = torch.argmax(probs, dim=-1).unsqueeze(0)
+    # @torch.no_grad()
+    # def stream(self, idx, max_new_tokens, temperature=1.0, top_k=None, gen_mode="greedy"):
+    #     """
+    #     Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
+    #     the sequence max_new_tokens times, feeding the predictions back into the model each time.
+    #     Most likely you'll want to make sure to be in model.eval() mode of operation for this.
+    #     """
+    #     for _ in range(max_new_tokens):
+    #         # if the sequence context is growing too long we must crop it at block_size
+    #         idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
+    #         # forward the model to get the logits for the index in the sequence
+    #         logits, _ = self(idx_cond, eval=True)
+    #         # pluck the logits at the final step and scale by desired temperature
+    #         logits = logits[:, -1, :] / temperature
+    #         # optionally crop the logits to only the top k options
+    #         if top_k is not None:
+    #             v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
+    #             logits[logits < v[:, [-1]]] = -float('Inf')
+    #         # apply softmax to convert logits to (normalized) probabilities
+    #         probs = F.softmax(logits, dim=-1)
+    #         # sample from the distribution
+    #         if gen_mode == 'greedy':
+    #             idx_next = torch.argmax(probs, dim=-1).unsqueeze(0)

-            else:
-                idx_next = torch.multinomial(probs, num_samples=1)
-            # print(idx_next.shape, idx.shape)
-            idx = torch.cat((idx, idx_next), dim=1)
-            # append sampled index to the running sequence and continue
-            yield idx_next
+    #         else:
+    #             idx_next = torch.multinomial(probs, num_samples=1)
+    #         # print(idx_next.shape, idx.shape)
+    #         idx = torch.cat((idx, idx_next), dim=1)
+    #         # append sampled index to the running sequence and continue
+    #         yield idx_next


     def crop_block_size(self, block_size):
@@ -166,23 +263,23 @@ AutoModel.register(GPTJXConfig,GPTJXForCausalLM)
 AutoModelForCausalLM.register(GPTJXConfig, GPTJXForCausalLM)


-if __name__ == '__main__':
-    from transformers import AutoTokenizer
+# if __name__ == '__main__':
+#     from transformers import AutoTokenizer

-    tokenizer = AutoTokenizer.from_pretrained("BeardedMonster/SabiYarn")
-    input_ids = tokenizer("ba wo ni?", return_tensors="pt")["input_ids"]
-    targets = input_ids
+#     tokenizer = AutoTokenizer.from_pretrained("BeardedMonster/SabiYarn")
+#     input_ids = tokenizer("Awọn eeyan Cairo, ni Egypt ti bẹrẹ si n to lawọn ileesẹ to n ṣe burẹdi bayii.", return_tensors="pt")["input_ids"]
+#     targets = input_ids

-    # config = GPTJConfig()
-    # config.save_pretrained("gptj-config")
-    # new_config = GPTJ.from_pretrained("gptj-config")
-    # model = GPTJ(config)
-    # state_dict = torch.load('model.pt', map_location="cpu")
-    # model.load_state_dict(state_dict)
-    model = GPTJXForCausalLM.from_pretrained("/pretrainedmodel")
-    # model.save_pretrained("/pretrainedmodel")
-    outputs = model(input_ids, targets)
-    print(outputs)
-    output = model.generate(input_ids, max_new_tokens=100)
-    print(tokenizer.decode(output[0]))
+#     # config = GPTJConfig()
+#     # config.save_pretrained("gptj-config")
+#     # new_config = GPTJ.from_pretrained("gptj-config")
+#     # model = GPTJ(config)
+#     # state_dict = torch.load('model.pt', map_location="cpu")
+#     # model.load_state_dict(state_dict)
+#     model = GPTJXForCausalLM.from_pretrained("/pretrainedmodel")
+#     # model.save_pretrained("/pretrainedmodel")
+#     # outputs = model(input_ids, targets)
+#     # print(outputs)
+#     output = model.generate(input_ids, max_new_tokens=50)
+#     print(tokenizer.decode(output[0]))
     # print(new_config)
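
Because this commit inlines the LayerNorm/attention/MLP/BlockJ classes into pretrained_model.py and registers GPTJXForCausalLM with AutoConfig, AutoModel, and AutoModelForCausalLM, the uploaded checkpoint should be loadable straight from the Hub as custom code. Below is a minimal usage sketch, not taken from the commit: it assumes the Hub repo id "BeardedMonster/SabiYarn" (the id used in the commented-out test block), reuses the "ba wo ni?" prompt from the old test code, and substitutes transformers' built-in TextStreamer for the stream() method that this commit comments out.

# Minimal usage sketch. Assumptions: the checkpoint and this custom code live in
# the Hub repo "BeardedMonster/SabiYarn"; trust_remote_code=True is needed since
# GPTJXForCausalLM is defined in the repo, not in the transformers library.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer

repo_id = "BeardedMonster/SabiYarn"  # assumed repo id, taken from the test block above
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)
model.eval()

input_ids = tokenizer("ba wo ni?", return_tensors="pt")["input_ids"]

# Plain generation, mirroring the commented-out __main__ block.
with torch.no_grad():
    output = model.generate(input_ids, max_new_tokens=50)
print(tokenizer.decode(output[0]))

# Token-by-token console output via TextStreamer, as a stand-in for the
# now-commented-out stream() generator.
streamer = TextStreamer(tokenizer, skip_prompt=True)
with torch.no_grad():
    model.generate(input_ids, max_new_tokens=50, streamer=streamer)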