m8than committed
Commit ef217c5 · verified · 1 Parent(s): 25755e9

Upload folder using huggingface_hub

__init__.py ADDED
@@ -0,0 +1,13 @@
+# -*- coding: utf-8 -*-
+
+from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
+
+from fla.models.rwkv7.configuration_rwkv7 import RWKV7Config
+from fla.models.rwkv7.modeling_rwkv7 import RWKV7ForCausalLM, RWKV7Model
+
+AutoConfig.register(RWKV7Config.model_type, RWKV7Config)
+AutoModel.register(RWKV7Config, RWKV7Model)
+AutoModelForCausalLM.register(RWKV7Config, RWKV7ForCausalLM)
+
+
+__all__ = ['RWKV7Config', 'RWKV7ForCausalLM', 'RWKV7Model']
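This `__init__.py` registers the RWKV7 classes with the transformers Auto factories, so the checkpoint can be loaded through the usual Auto API once the registration has run. A minimal loading sketch, assuming the `fla` (flash-linear-attention) package is installed (importing it performs the same registration) and using a placeholder repo id:

```python
# Minimal loading sketch. Assumptions: `fla` is installed and importable,
# and "path/to/this-upload" is a placeholder for wherever these files live.
import fla  # noqa: F401  # importing fla registers RWKV7Config/RWKV7Model/RWKV7ForCausalLM

from transformers import AutoConfig, AutoModelForCausalLM

repo_id = "path/to/this-upload"  # placeholder, not the actual repo name
config = AutoConfig.from_pretrained(repo_id)           # model_type "rwkv7" resolves to RWKV7Config
model = AutoModelForCausalLM.from_pretrained(repo_id)  # resolves to RWKV7ForCausalLM
print(type(config).__name__, type(model).__name__)
```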
config.json ADDED
@@ -0,0 +1,31 @@
+{
+  "_attn_implementation_autoset": true,
+  "a_low_rank_dim": 96,
+  "attn": null,
+  "attn_mode": "chunk",
+  "bos_token_id": 1,
+  "decay_low_rank_dim": 96,
+  "eos_token_id": 2,
+  "fuse_cross_entropy": true,
+  "fuse_norm": true,
+  "gate_low_rank_dim": 256,
+  "head_dim": 64,
+  "hidden_act": "sqrelu",
+  "hidden_ratio": 4.0,
+  "hidden_size": 2048,
+  "initializer_range": 0.02,
+  "intermediate_size": 8192,
+  "max_position_embeddings": 2048,
+  "model_type": "rwkv7",
+  "norm_bias": true,
+  "norm_eps": 1e-05,
+  "norm_first": true,
+  "num_heads": null,
+  "num_hidden_layers": 24,
+  "tie_word_embeddings": false,
+  "torch_dtype": "float32",
+  "transformers_version": "4.48.1",
+  "use_cache": true,
+  "v_low_rank_dim": 64,
+  "vocab_size": 65536
+}
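For reference, the `intermediate_size: 8192` value is consistent with the rounding rule used by `RWKV7FeedForward` in `modeling_rwkv7.py` below (hidden_size * hidden_ratio, rounded up to a multiple of 32). A small sketch of that arithmetic; the helper name is illustrative only, not part of this upload:

```python
def ffn_intermediate_size(hidden_size: int, hidden_ratio: int = 4) -> int:
    # Mirrors RWKV7FeedForward: round hidden_size * hidden_ratio up to a multiple of 32.
    intermediate = int(hidden_size * hidden_ratio)
    return 32 * ((intermediate + 32 - 1) // 32)

print(ffn_intermediate_size(2048, 4))  # 8192, matching "intermediate_size" in this config
```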
configuration_rwkv7.py ADDED
@@ -0,0 +1,83 @@
+# -*- coding: utf-8 -*-
+
+from typing import Dict, Optional
+
+from transformers.configuration_utils import PretrainedConfig
+
+
+class RWKV7Config(PretrainedConfig):
+
+    model_type = 'rwkv7'
+    keys_to_ignore_at_inference = ['past_key_values']
+
+    def __init__(
+        self,
+        attn_mode: str = "chunk",
+        hidden_size: int = 2048,
+        hidden_ratio: Optional[int] = 4,
+        intermediate_size: Optional[int] = None,
+        num_hidden_layers: int = 24,
+        head_dim: Optional[int] = 64,
+        num_heads: Optional[int] = None,
+        decay_low_rank_dim: int = 64,
+        gate_low_rank_dim: int = 128,
+        a_low_rank_dim: int = 64,
+        v_low_rank_dim: int = 16,
+        hidden_act: str = "sqrelu",
+        max_position_embeddings: int = 2048,
+        norm_first: bool = True,
+        norm_bias: bool = True,
+        norm_eps: float = 1e-5,
+        attn: Optional[Dict] = None,
+        use_cache: bool = True,
+        pad_token_id: Optional[int] = None,
+        bos_token_id: int = 1,
+        eos_token_id: int = 2,
+        tie_word_embeddings: bool = False,
+        initializer_range: float = 0.02,
+        fuse_norm: bool = True,
+        fuse_cross_entropy: bool = True,
+        vocab_size: int = 32000,
+        **kwargs
+    ):
+        self.attn_mode = attn_mode
+        self.hidden_size = hidden_size
+        self.hidden_ratio = hidden_ratio
+        self.intermediate_size = intermediate_size
+        self.norm_first = norm_first
+        self.num_hidden_layers = num_hidden_layers
+        self.head_dim = head_dim
+        self.num_heads = num_heads
+        self.decay_low_rank_dim = decay_low_rank_dim
+        self.gate_low_rank_dim = gate_low_rank_dim
+        self.a_low_rank_dim = a_low_rank_dim
+        self.v_low_rank_dim = v_low_rank_dim
+        self.hidden_act = hidden_act
+        self.max_position_embeddings = max_position_embeddings
+        self.norm_bias = norm_bias
+        self.norm_eps = norm_eps
+        self.attn = attn
+        self.use_cache = use_cache
+        self.initializer_range = initializer_range
+        self.fuse_norm = fuse_norm
+        self.fuse_cross_entropy = fuse_cross_entropy
+        self.vocab_size = vocab_size
+
+        if attn is not None:
+            if not isinstance(attn, Dict):
+                raise ValueError("attn must be a dictionary")
+            if 'layers' not in attn:
+                raise ValueError("Layer indices must be provided to initialize hybrid attention layers")
+            if 'num_heads' not in attn:
+                raise ValueError("Number of heads must be provided to initialize hybrid attention layers")
+            attn['num_kv_heads'] = attn.get('num_kv_heads', attn['num_heads'])
+            attn['window_size'] = attn.get('window_size', None)
+            attn['rope_theta'] = attn.get('rope_theta', 10000.)
+
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
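`RWKV7Config` optionally accepts an `attn` dict to place standard softmax-attention blocks at selected layers; it requires `layers` and `num_heads` and fills in `num_kv_heads`, `window_size`, and `rope_theta` defaults. A hedged construction sketch (the layer indices chosen here are arbitrary, for illustration only):

```python
from fla.models.rwkv7.configuration_rwkv7 import RWKV7Config

# Illustrative only: use softmax attention in layers 11 and 23, RWKV7 elsewhere.
hybrid_config = RWKV7Config(
    hidden_size=2048,
    num_hidden_layers=24,
    vocab_size=65536,
    attn={'layers': [11, 23], 'num_heads': 32},  # num_kv_heads/window_size/rope_theta get defaults
)
print(hybrid_config.attn)
# {'layers': [11, 23], 'num_heads': 32, 'num_kv_heads': 32, 'window_size': None, 'rope_theta': 10000.0}
```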
model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0bf1fb5bf5cb90e0b401164cb5230eb28b66c517ce11ab8d510faa73aeefc63f
+size 6109691400
modeling_rwkv7.py ADDED
@@ -0,0 +1,465 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import annotations
+
+import math
+import warnings
+from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+from transformers.generation import GenerationMixin
+from transformers.modeling_outputs import (BaseModelOutputWithPast,
+                                           CausalLMOutputWithPast)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import logging
+
+from fla.layers.attn import Attention
+from fla.layers.rwkv7 import RWKV7Attention
+from fla.models.rwkv7.configuration_rwkv7 import RWKV7Config
+from fla.models.utils import Cache
+from fla.modules import (FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss,
+                         LayerNorm)
+from fla.modules.activations import ACT2FN
+
+if TYPE_CHECKING:
+    from transformers.processing_utils import Unpack
+
+logger = logging.get_logger(__name__)
+
+
+class RWKV7FeedForward(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        hidden_ratio: Optional[int] = None,
+        intermediate_size: Optional[int] = None,
+        hidden_act: str = 'sqrelu',
+        layer_idx: Optional[int] = None
+    ) -> None:
+        super().__init__()
+
+        self.hidden_size = hidden_size
+        if hidden_ratio is None:
+            hidden_ratio = 4
+        if intermediate_size is None:
+            intermediate_size = int(hidden_size * hidden_ratio)
+            intermediate_size = 32 * ((intermediate_size + 32 - 1) // 32)
+        self.hidden_ratio = hidden_ratio
+        self.intermediate_size = intermediate_size
+
+        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
+
+        self.x_k = nn.Parameter(torch.zeros(hidden_size))
+
+        self.key = nn.Linear(hidden_size, intermediate_size, bias=False)
+        self.value = nn.Linear(intermediate_size, hidden_size, bias=False)
+        self.act_fn = ACT2FN[hidden_act]
+
+        self.layer_idx = layer_idx
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        state: Optional[Cache] = None
+    ) -> Tuple[torch.Tensor, Optional[Cache]]:
+        if attention_mask is not None:
+            x = x.mul(attention_mask[:, -x.shape[-2]:, None])
+        if x.shape[1] == 1 and state is not None:
+            shifted = state[self.layer_idx]['ffn_state'].unsqueeze(1)
+        else:
+            shifted = self.time_shift(x)
+            if state is not None and state[self.layer_idx]['ffn_state'] is not None:
+                shifted[:, 0] = state[self.layer_idx]['ffn_state'][-1]
+        if state is not None:
+            # no need to update the offset twice
+            state.update(ffn_state=x[:, -1], layer_idx=self.layer_idx, offset=0)
+        return self.value(self.act_fn(self.key(x + (shifted - x) * self.x_k))), state
+
+
+class RWKV7Block(nn.Module):
+
+    def __init__(
+        self,
+        config: RWKV7Config,
+        layer_idx: int
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+
+        self.config = config
+        self.layer_idx = layer_idx
+
+        if config.norm_first and layer_idx == 0:
+            self.pre_norm = LayerNorm(hidden_size=config.hidden_size, bias=config.norm_bias, eps=config.norm_eps)
+        self.attn_norm = LayerNorm(hidden_size=config.hidden_size, bias=config.norm_bias, eps=config.norm_eps)
+        if config.attn is not None and layer_idx in config.attn['layers']:
+            self.attn = Attention(
+                hidden_size=config.hidden_size,
+                num_heads=config.attn['num_heads'],
+                num_kv_heads=config.attn['num_kv_heads'],
+                window_size=config.attn['window_size'],
+                rope_theta=config.attn['rope_theta'],
+                max_position_embeddings=config.max_position_embeddings,
+                layer_idx=layer_idx
+            )
+        else:
+            self.attn = RWKV7Attention(
+                mode=config.attn_mode,
+                hidden_size=config.hidden_size,
+                head_dim=config.head_dim,
+                num_heads=config.num_heads,
+                decay_low_rank_dim=config.decay_low_rank_dim,
+                gate_low_rank_dim=config.gate_low_rank_dim,
+                a_low_rank_dim=config.a_low_rank_dim,
+                v_low_rank_dim=config.v_low_rank_dim,
+                norm_eps=config.norm_eps,
+                fuse_norm=config.fuse_norm,
+                layer_idx=layer_idx
+            )
+        self.ffn_norm = LayerNorm(hidden_size=config.hidden_size, bias=config.norm_bias, eps=config.norm_eps)
+        self.ffn = RWKV7FeedForward(
+            hidden_size=config.hidden_size,
+            hidden_ratio=config.hidden_ratio,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            layer_idx=layer_idx
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Cache] = None,
+        use_cache: Optional[bool] = False,
+        output_attentions: Optional[bool] = False,
+        v_first: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        residual = self.pre_norm(hidden_states) if hasattr(self, 'pre_norm') else hidden_states
+        hidden_states = self.attn_norm(residual)
+        hidden_states, attentions, past_key_values, v_first = self.attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            v_first=v_first,
+            **kwargs
+        )
+        hidden_states, residual = self.ffn_norm(hidden_states, residual, True)
+        hidden_states, past_key_values = self.ffn(hidden_states, attention_mask, past_key_values)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states, attentions, past_key_values, v_first)
+
+        return outputs
+
+
+class RWKV7PreTrainedModel(PreTrainedModel):
+
+    config_class = RWKV7Config
+    base_model_prefix = 'model'
+    supports_gradient_checkpointing = True
+    _no_split_modules = ['RWKV7Block']
+    _supports_cache_class = True
+
+    def __init__(self, *inputs, **kwargs):
+        super().__init__(*inputs, **kwargs)
+
+    def _init_weights(
+        self,
+        module: nn.Module,
+        rescale_prenorm_residual: bool = True,
+        num_residuals_per_layer: int = 2,
+    ):
+        if isinstance(module, (nn.Linear, nn.Conv1d)):
+            # Slightly different from the TF version which uses truncated_normal for initialization
+            # cf https://github.com/pytorch/pytorch/pull/5617
+            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                nn.init.zeros_(module.bias)
+        elif isinstance(module, nn.Parameter):
+            nn.init.normal_(module, mean=0.0, std=self.config.initializer_range)
+        elif isinstance(module, nn.Embedding):
+            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+        elif hasattr(module, 'reset_parameters'):
+            module.reset_parameters()
+
+        if rescale_prenorm_residual:
+            # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
+            # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
+            # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
+            # > -- GPT-2 :: https://openai.com/blog/better-language-models/
+            #
+            # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
+            for name, p in module.named_parameters():
+                if name in ["o_proj.weight", "down_proj.weight"]:
+                    # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
+                    # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
+                    # We need to reinit p since this code could be called multiple times
+                    # Having just p *= scale would repeatedly scale it down
+                    with torch.no_grad():
+                        p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
+
+
+class RWKV7Model(RWKV7PreTrainedModel):
+
+    def __init__(self, config: RWKV7Config):
+        super().__init__(config)
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+
+        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
+        self.layers = nn.ModuleList([RWKV7Block(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
+        self.norm = LayerNorm(config.hidden_size, bias=config.norm_bias, eps=config.norm_eps)
+
+        self.gradient_checkpointing = False
+
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.embeddings
+
+    def set_input_embeddings(self, value):
+        self.embeddings = value
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,  # noqa
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs: Unpack[Dict]
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
+        if output_attentions:
+            warnings.warn("`RWKV7Model` does not support `output_attentions` for now, setting it to `False`.")
+            output_attentions = False
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # retrieve input_ids and inputs_embeds
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        if input_ids is None and inputs_embeds is None:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embeddings(input_ids)
+        hidden_states = inputs_embeds
+
+        if use_cache and not isinstance(past_key_values, Cache):
+            past_key_values = Cache.from_legacy_cache(past_key_values)
+
+        if self.gradient_checkpointing and self.training and use_cache:
+            logger.warning_once("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
+            use_cache = False
+
+        all_hidden_states = () if output_hidden_states else None
+        all_attns = () if output_attentions else None
+
+        v_first = torch.zeros_like(hidden_states)
+        for layer in self.layers:
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            if self.gradient_checkpointing and self.training:
+                hidden_states, attentions, past_key_values, v_first = self._gradient_checkpointing_func(
+                    layer.__call__,
+                    hidden_states,
+                    attention_mask,
+                    past_key_values,
+                    use_cache,
+                    output_attentions,
+                    v_first,
+                    **kwargs
+                )
+            else:
+                hidden_states, attentions, past_key_values, v_first = layer(
+                    hidden_states,
+                    attention_mask=attention_mask,
+                    past_key_values=past_key_values,
+                    use_cache=use_cache,
+                    output_attentions=output_attentions,
+                    v_first=v_first,
+                    **kwargs
+                )
+
+            if output_attentions:
+                all_attns += (attentions,)
+
+        hidden_states = self.norm(hidden_states)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        if not return_dict:
+            return tuple(i for i in [hidden_states, past_key_values, all_hidden_states, all_attns] if i is not None)
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=past_key_values,
+            hidden_states=all_hidden_states,
+            attentions=all_attns
+        )
+
+
+class RWKV7ForCausalLM(RWKV7PreTrainedModel, GenerationMixin):
+
+    _tied_weights_keys = ["lm_head.weight"]
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = RWKV7Model(config)
+        self.vocab_size = config.vocab_size
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.embeddings
+
+    def set_input_embeddings(self, value):
+        self.model.embeddings = value
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
+    def set_decoder(self, decoder):
+        self.model = decoder
+
+    def get_decoder(self):
+        return self.model
+
+    def generate(self, *args, **kwargs):
+        try:
+            return super().generate(*args, **kwargs)
+        except AttributeError as exception:
+            if 'past_key_values' in str(exception):
+                raise AttributeError(
+                    f"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`, "
+                    f"which is not supported for {self.__class__.__name__}. "
+                    f"Try another generation strategy instead. "
+                    f"For the available generation strategies, check this doc: "
+                    f"https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
+                )
+            else:
+                raise exception
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids: torch.LongTensor = None,
+        past_key_values: Optional[Cache] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        use_cache: bool = True,
+        num_logits_to_keep: Optional[int] = None,
+        **kwargs
+    ):
+        # only use the last token of `input_ids` if `past_key_values` is passed along.
+        if past_key_values is not None:
+            input_ids = input_ids[:, -1:]
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {'inputs_embeds': inputs_embeds}
+        else:
+            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
+            # recompiles graphs as the stride of the inputs is a guard.
+            # Ref: https://github.com/huggingface/transformers/pull/29114
+            # TODO: use `next_tokens` directly instead.
+            model_inputs = {'input_ids': input_ids.contiguous()}
+
+        if num_logits_to_keep is not None:
+            model_inputs['num_logits_to_keep'] = num_logits_to_keep
+
+        model_inputs.update({
+            'past_key_values': past_key_values,
+            'use_cache': use_cache,
+            'attention_mask': attention_mask,
+            'num_logits_to_keep': num_logits_to_keep,
+        })
+        return model_inputs
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Cache] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        num_logits_to_keep: Optional[int] = 0,
+        **kwargs: Unpack[Dict]
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            **kwargs
+        )
+
+        hidden_states = outputs[0]
+        fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
+
+        loss, logits = None, None
+        if not fuse_linear_and_cross_entropy or labels is None:
+            logits = self.lm_head(hidden_states[:, -num_logits_to_keep:])
+        if labels is not None:
+            if self.config.fuse_cross_entropy:
+                if fuse_linear_and_cross_entropy:
+                    loss_fct = FusedLinearCrossEntropyLoss()
+                else:
+                    loss_fct = FusedCrossEntropyLoss(inplace_backward=True)
+            else:
+                loss_fct = nn.CrossEntropyLoss()
+            # Enable model parallelism
+            labels = labels.to(hidden_states.device)
+            labels = torch.cat((labels[..., 1:], torch.full_like(labels[:, :1], loss_fct.ignore_index)), 1)
+            if fuse_linear_and_cross_entropy:
+                loss = loss_fct(hidden_states.view(-1, self.config.hidden_size),
+                                labels.view(-1),
+                                self.lm_head.weight,
+                                self.lm_head.bias)
+            else:
+                loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
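A short generation sketch with `RWKV7ForCausalLM` as defined above. Assumptions: the `fla` package is installed so the custom classes resolve, the tokenizer files from this upload are used, and the repo path is a placeholder:

```python
# Illustrative only; paths are placeholders, not the actual repo name.
import torch
import fla  # noqa: F401  # registers the RWKV7 classes with the Auto factories
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "path/to/this-upload"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(repo_id, torch_dtype=torch.float32)

inputs = tokenizer("The Eiffel Tower is located in", return_tensors="pt")
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=32, use_cache=True)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```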
special_tokens_map.json ADDED
@@ -0,0 +1,23 @@
+{
+  "bos_token": {
+    "content": "<|endoftext|>",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
+  "eos_token": {
+    "content": "<|endoftext|>",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
+  "unk_token": {
+    "content": "<|endoftext|>",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  }
+}
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,214 @@
+{
+  "add_bos_token": false,
+  "add_eos_token": false,
+  "add_prefix_space": false,
+  "added_tokens_decoder": {
+    "0": {
+      "content": "<|endoftext|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "1": {
+      "content": "<|padding|>",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "50254": {
+      "content": " ",
+      "lstrip": false,
+      "normalized": true,
+      "rstrip": false,
+      "single_word": false,
+      "special": false
+    },
+    "50255": {
+      "content": " ",
+      "lstrip": false,
+      "normalized": true,
+      "rstrip": false,
+      "single_word": false,
+      "special": false
+    },
+    "50256": {
+      "content": " ",
+      "lstrip": false,
+      "normalized": true,
+      "rstrip": false,
+      "single_word": false,
+      "special": false
+    },
+    "50257": {
+      "content": " ",
+      "lstrip": false,
+      "normalized": true,
+      "rstrip": false,
+      "single_word": false,
+      "special": false
+    },
+    "50258": {
+      "content": " ",
+      "lstrip": false,
+      "normalized": true,
+      "rstrip": false,
+      "single_word": false,
+      "special": false
+    },
+    "50259": {
+      "content": " ",
+      "lstrip": false,
+      "normalized": true,
+      "rstrip": false,
+      "single_word": false,
+      "special": false
+    },
+    "50260": {
+      "content": " ",
+      "lstrip": false,
+      "normalized": true,
+      "rstrip": false,
+      "single_word": false,
+      "special": false
+    },
+    "50261": {
+      "content": " ",
+      "lstrip": false,
+      "normalized": true,
+      "rstrip": false,
+      "single_word": false,
+      "special": false
+    },
+    "50262": {
+      "content": " ",
+      "lstrip": false,
+      "normalized": true,
+      "rstrip": false,
+      "single_word": false,
+      "special": false
+    },
+    "50263": {
+      "content": " ",
+      "lstrip": false,
+      "normalized": true,
+      "rstrip": false,
+      "single_word": false,
+      "special": false
+    },
+    "50264": {
+      "content": " ",
+      "lstrip": false,
+      "normalized": true,
+      "rstrip": false,
+      "single_word": false,
+      "special": false
+    },
+    "50265": {
+      "content": " ",
+      "lstrip": false,
+      "normalized": true,
+      "rstrip": false,
+      "single_word": false,
+      "special": false
+    },
+    "50266": {
+      "content": " ",
+      "lstrip": false,
+      "normalized": true,
+      "rstrip": false,
+      "single_word": false,
+      "special": false
+    },
+    "50267": {
+      "content": " ",
+      "lstrip": false,
+      "normalized": true,
+      "rstrip": false,
+      "single_word": false,
+      "special": false
+    },
+    "50268": {
+      "content": " ",
+      "lstrip": false,
+      "normalized": true,
+      "rstrip": false,
+      "single_word": false,
+      "special": false
+    },
+    "50269": {
+      "content": " ",
+      "lstrip": false,
+      "normalized": true,
+      "rstrip": false,
+      "single_word": false,
+      "special": false
+    },
+    "50270": {
+      "content": " ",
+      "lstrip": false,
+      "normalized": true,
+      "rstrip": false,
+      "single_word": false,
+      "special": false
+    },
+    "50271": {
+      "content": " ",
+      "lstrip": false,
+      "normalized": true,
+      "rstrip": false,
+      "single_word": false,
+      "special": false
+    },
+    "50272": {
+      "content": " ",
+      "lstrip": false,
+      "normalized": true,
+      "rstrip": false,
+      "single_word": false,
+      "special": false
+    },
+    "50273": {
+      "content": " ",
+      "lstrip": false,
+      "normalized": true,
+      "rstrip": false,
+      "single_word": false,
+      "special": false
+    },
+    "50274": {
+      "content": " ",
+      "lstrip": false,
+      "normalized": true,
+      "rstrip": false,
+      "single_word": false,
+      "special": false
+    },
+    "50275": {
+      "content": " ",
+      "lstrip": false,
+      "normalized": true,
+      "rstrip": false,
+      "single_word": false,
+      "special": false
+    },
+    "50276": {
+      "content": " ",
+      "lstrip": false,
+      "normalized": true,
+      "rstrip": false,
+      "single_word": false,
+      "special": false
+    }
+  },
+  "bos_token": "<|endoftext|>",
+  "clean_up_tokenization_spaces": false,
+  "eos_token": "<|endoftext|>",
+  "model_max_length": 1000000000000000019884624838656,
+  "pad_token": null,
+  "tokenizer_class": "GPTNeoXTokenizer",
+  "unk_token": "<|endoftext|>"
+}