sambitchakhf03 committed
Commit 223a13c
1 Parent(s): 58392dc

Create configuration_RW.py

Files changed (1):
  configuration_RW.py  +64 -0
configuration_RW.py ADDED
@@ -0,0 +1,64 @@
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.utils import logging
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class RWConfig(PretrainedConfig):
+     model_type = "RefinedWebModel"
+     keys_to_ignore_at_inference = ["past_key_values"]
+     attribute_map = {
+         "num_hidden_layers": "n_layer",
+         "num_attention_heads": "n_head",
+     }
+
+     def __init__(
+         self,
+         vocab_size=250880,
+         hidden_size=64,
+         n_layer=2,
+         n_head=8,
+         layer_norm_epsilon=1e-5,
+         initializer_range=0.02,
+         use_cache=True,
+         bos_token_id=1,
+         eos_token_id=2,
+         apply_residual_connection_post_layernorm=False,
+         hidden_dropout=0.0,
+         attention_dropout=0.0,
+         multi_query=False,
+         alibi=False,
+         bias=False,
+         parallel_attn=False,
+         **kwargs,
+     ):
+         self.vocab_size = vocab_size
+         # Backward compatibility with n_embed kwarg
+         n_embed = kwargs.pop("n_embed", None)
+         self.hidden_size = hidden_size if n_embed is None else n_embed
+         self.n_layer = n_layer
+         self.n_head = n_head
+         self.layer_norm_epsilon = layer_norm_epsilon
+         self.initializer_range = initializer_range
+         self.use_cache = use_cache
+         self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
+         self.hidden_dropout = hidden_dropout
+         self.attention_dropout = attention_dropout
+
+         self.bos_token_id = bos_token_id
+         self.eos_token_id = eos_token_id
+         self.multi_query = multi_query
+         self.alibi = alibi
+         self.bias = bias
+         self.parallel_attn = parallel_attn
+
+         super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
+
+     @property
+     def head_dim(self):
+         return self.hidden_size // self.n_head
+
+     @property
+     def rotary(self):
+         return not self.alibi
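
For context, a minimal usage sketch of the new RWConfig class (not part of this commit). The import path and the constructor values below are illustrative assumptions, not defaults taken from this file; adjust them to the checkpoint you are configuring.

# Sketch: instantiate RWConfig directly and read its derived properties.
# Assumes configuration_RW.py is importable from the working directory;
# the argument values are examples only.
from configuration_RW import RWConfig

config = RWConfig(
    hidden_size=4544,   # example width, not a default from this file
    n_head=71,          # example head count
    n_layer=32,
    multi_query=True,
    alibi=False,        # alibi=False means rotary position embeddings apply
)

print(config.head_dim)           # hidden_size // n_head -> 64
print(config.rotary)             # True, because alibi is False
print(config.num_hidden_layers)  # 32, resolved to n_layer via attribute_map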