erfanzar committed on
Commit
bf98250
·
verified ·
1 Parent(s): ffea312

Upload Qwen2ForSequenceClassification

Browse files
Files changed (4) hide show
  1. .gitattributes +1 -0
  2. README.md +80 -0
  3. config.json +97 -0
  4. easydel-model.parameters +3 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ easydel-model.parameters filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ ---
3
+ tags:
4
+ - EasyDeL
5
+ - Qwen2ForSequenceClassification
6
+ - safetensors
7
+ - TPU
8
+ - GPU
9
+ - XLA
10
+ - Flax
11
+ ---
12
+ # EasyDeL/Qwen2-0.5B-RewardModel
13
+
14
+ [![EasyDeL](https://img.shields.io/badge/🤗_EasyDeL-0.1.0-blue.svg)](https://github.com/erfanzar/EasyDeL)
15
+ [![Model Type](https://img.shields.io/badge/Model_Type-Qwen2ForSequenceClassification-green.svg)](https://github.com/erfanzar/EasyDeL)
16
+
17
+ A model implemented using the EasyDeL framework, designed to deliver optimal performance for large-scale natural language processing tasks.
18
+
19
+ ## Overview
20
+
21
+ EasyDeL provides an efficient, highly optimized, and customizable machine learning model compatible with both GPU and TPU environments. Built with JAX, this model supports advanced features such as sharded model parallelism and customized kernels, making it suitable for distributed training and inference.
22
+
23
+ ## Features
24
+
25
+
26
+ - **Efficient Implementation**: Built with JAX/Flax for high-performance computation.
27
+ - **Multi-Device Support**: Optimized to run on TPU, GPU, and CPU environments, with support for sharding the model over 2^(1-1000+) devices.
28
+ - **Sharded Model Parallelism**: Supports model parallelism across multiple devices for scalability.
29
+ - **Customizable Precision**: Allows specification of floating-point precision for performance optimization.
30
+
31
+
32
+ ## Installation
33
+
34
+ To install EasyDeL, simply run:
35
+
36
+ ```bash
37
+ pip install easydel
38
+ ```
39
+
40
+ ## Usage
41
+
42
+ ### Loading the Pre-trained Model
43
+
44
+ To load a pre-trained version of the model with EasyDeL:
45
+
46
+ ```python
47
+ import easydel as ed
+ from easydel import AutoEasyDeLModelForCausalLM
48
+ from jax import numpy as jnp, lax
49
+
50
+ max_length = None # set to an integer to lower the memory used for caching
51
+
52
+ # Load model and parameters
53
+ model = AutoEasyDeLModelForCausalLM.from_pretrained(
54
+ "EasyDeL/Qwen2-0.5B-RewardModel",
55
+ config_kwargs=ed.EasyDeLBaseConfigDict(
56
+ use_scan_mlp=False,
57
+ attn_dtype=jnp.float16,
58
+ freq_max_position_embeddings=max_length,
59
+ mask_max_position_embeddings=max_length,
60
+ attn_mechanism=ed.AttentionMechanisms.FLASH_ATTN2
61
+ ),
62
+ dtype=jnp.float16,
63
+ param_dtype=jnp.float16,
64
+ precision=lax.Precision("fastest"),
65
+ auto_shard_model=True,
66
+ )
67
+ ```
68
+
69
+ ## Supported Tasks
70
+
71
+
72
+ [Need more information]
73
+
74
+
75
+ ## Limitations
76
+
77
+
78
+ - **Hardware Dependency**: Performance can vary significantly based on the hardware used.
79
+ - **JAX/Flax Setup Required**: The environment must support JAX/Flax for optimal use.
80
+ - **Experimental Features**: Some features (like custom kernel usage or ed-ops) may require additional configuration and tuning.
config.json ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "Qwen2ForSequenceClassification"
4
+ ],
5
+ "attention_axis_name": "sp",
6
+ "attention_dropout": 0.0,
7
+ "attn_mechanism": "vanilla",
8
+ "axis_dims": [
9
+ 1,
10
+ 1,
11
+ 1,
12
+ -1
13
+ ],
14
+ "axis_names": [
15
+ "dp",
16
+ "fsdp",
17
+ "tp",
18
+ "sp"
19
+ ],
20
+ "backend": null,
21
+ "bits": null,
22
+ "blocksize_b": 1,
23
+ "blocksize_k": 128,
24
+ "blocksize_q": 128,
25
+ "bos_token_id": 151643,
26
+ "easy_method": "train",
27
+ "embd_pdrop": 0.0,
28
+ "eos_token_id": 151645,
29
+ "fcm_max_ratio": 0.0,
30
+ "fcm_min_ratio": 0.0,
31
+ "flash_attention_backward_pass_impl": "triton",
32
+ "freq_max_position_embeddings": 2048,
33
+ "gradient_checkpointing": "nothing_saveable",
34
+ "hardware_abstraction": false,
35
+ "head_dim": 64,
36
+ "hidden_act": "silu",
37
+ "hidden_size": 896,
38
+ "initializer_range": 0.02,
39
+ "intermediate_size": 4864,
40
+ "kv_cache_quantization_blocksize": 64,
41
+ "kv_cache_quantization_method": "None",
42
+ "kv_cache_sharding_sequence_axis_name": "sp",
43
+ "mask_max_position_embeddings": 2048,
44
+ "max_position_embeddings": 32768,
45
+ "max_window_layers": 24,
46
+ "model_type": "qwen2",
47
+ "num_attention_heads": 14,
48
+ "num_hidden_layers": 24,
49
+ "num_key_value_heads": 2,
50
+ "number_rep_kv": 1,
51
+ "pad_token_id": 151643,
52
+ "pallas_k_block_size": null,
53
+ "pallas_m_block_size": null,
54
+ "pallas_n_block_size": null,
55
+ "partition_axis": [
56
+ [
57
+ "fsdp",
58
+ "dp"
59
+ ],
60
+ "sp",
61
+ "sp",
62
+ "tp",
63
+ "sp",
64
+ "tp",
65
+ null,
66
+ null,
67
+ null,
68
+ null,
69
+ "tp",
70
+ "sp",
71
+ null
72
+ ],
73
+ "platform": "jax",
74
+ "pretraining_tp": 1,
75
+ "quantization_blocksize": 64,
76
+ "quantization_method": "None",
77
+ "quantization_pattern": ".*",
78
+ "resid_pdrop": 0.0,
79
+ "rms_norm_eps": 1e-06,
80
+ "rope_scaling": null,
81
+ "rope_theta": 10000.0,
82
+ "scan_attention_layers": false,
83
+ "scan_layers": true,
84
+ "scan_mlp_chunk_size": 1024,
85
+ "scan_ring_attention": true,
86
+ "shard_attention_computation": true,
87
+ "sliding_window": 32768,
88
+ "tie_word_embeddings": false,
89
+ "torch_dtype": "bfloat16",
90
+ "transformers_version": "4.47.1",
91
+ "use_cache": true,
92
+ "use_scan_mlp": false,
93
+ "use_sharded_kv_caching": false,
94
+ "use_sharding_constraint": false,
95
+ "use_sliding_window": false,
96
+ "vocab_size": 151936
97
+ }
easydel-model.parameters ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f8c9dc3bd160d63d69b598bc326c7c637dbe09876bd5587c1150ffd1c2592ff3
3
+ size 988101464