kz919 committed 1a7246c (verified) · Parent(s): 6b751ef

Update README.md

Files changed (1): README.md (+221 -3)
 
---
license: apache-2.0
datasets:
- amphora/QwQ-LongCoT-130K
language:
- en
metrics:
- perplexity
base_model:
- Qwen/Qwen2.5-0.5B-Instruct
---

## Model Details:

- **Base Model:** Qwen/Qwen2-0.5B-Instruct
- **Teacher Model:** Qwen/QwQ-32B-Preview
- **Distillation Framework:** Generative Knowledge Distillation (GKD)
- **Task Type:** Conversational AI / Causal Language Modeling
- **Parameters:** 0.5B
- **Special Features:**
  - Optional LoRA (`LoraConfig`) adapters for parameter-efficient fine-tuning
  - Gradient checkpointing for memory-efficient training
  - Step-by-step reasoning capabilities for better problem-solving

---

## Training:

QwQ-0.5B-Distilled was trained on the **QwQ-LongCoT-130K** dataset, a curated collection of long chain-of-thought examples for reasoning and conversational tasks. The GKD framework trains the student to match the teacher's token-level output distribution, so its predictions stay aligned with the teacher's high-quality responses; the sketch below illustrates the kind of divergence being minimized.
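
As a rough illustration, the snippet below computes a generalized Jensen-Shannon divergence between teacher and student token distributions, which is the type of objective GKD minimizes. This is not TRL's exact implementation; `beta` interpolates between forward and reverse KL (see the TRL docs for its exact convention), and the trainer's `lmbda` (not shown here) controls how often training runs on student-generated sequences.

```python
# Illustrative sketch only (not TRL's exact GKD loss). Assumes 0 < beta < 1.
import math
import torch
import torch.nn.functional as F

def generalized_jsd(student_logits, teacher_logits, beta=0.5, temperature=0.9):
    # Token-level log-distributions at the distillation temperature.
    log_p_s = F.log_softmax(student_logits / temperature, dim=-1)
    log_p_t = F.log_softmax(teacher_logits / temperature, dim=-1)
    # Log of the mixture M = beta * P_teacher + (1 - beta) * P_student.
    log_m = torch.logsumexp(
        torch.stack([log_p_t + math.log(beta), log_p_s + math.log(1 - beta)]),
        dim=0,
    )
    # JSD_beta = beta * KL(P_teacher || M) + (1 - beta) * KL(P_student || M).
    kl_t = F.kl_div(log_m, log_p_t, log_target=True, reduction="batchmean")
    kl_s = F.kl_div(log_m, log_p_s, log_target=True, reduction="batchmean")
    return beta * kl_t + (1 - beta) * kl_s
```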

### Training Progress:
[▓░░░░░░░░░░] 10%

### Training Script:

```python
import argparse

import torch
from datasets import Dataset, load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GKDConfig, GKDTrainer

parser = argparse.ArgumentParser()
parser.add_argument("--temperature", type=float, default=0.9)
parser.add_argument("--lmbda", type=float, default=0.5)
parser.add_argument("--beta", type=float, default=0.5)
parser.add_argument("--max_new_tokens", type=int, default=4096)
parser.add_argument("--output_dir", type=str, default="gkd-model")
parser.add_argument("--per_device_train_batch_size", type=int, default=1)
parser.add_argument("--gradient_accumulation_steps", type=int, default=16)
parser.add_argument("--gradient_checkpointing", action="store_true", default=False)
parser.add_argument("--resume_from_checkpoint", action="store_true", default=False)
parser.add_argument("--lora", action="store_true")
args = parser.parse_args()

# Convert each QwQ-LongCoT example into a chat-formatted message list.
qwq_dataset = load_dataset("amphora/QwQ-LongCoT-130K", split="train")
messages = []
for each in qwq_dataset:
    msg = [
        {"role": "system", "content": "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step."},
        {"role": "user", "content": each["problem"]},
        {"role": "assistant", "content": each["qwq"]},
    ]
    messages.append(msg)

TRAIN_SPLIT_RATIO = 0.9
train_size = int(TRAIN_SPLIT_RATIO * len(messages))
eval_size = len(messages) - train_size

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

# The teacher model to calculate the KL divergence against.
teacher_model = AutoModelForCausalLM.from_pretrained(
    "Qwen/QwQ-32B-Preview", torch_dtype=torch.bfloat16, device_map="auto"
)
# Trim the teacher's output head to the student's 151,936-token vocabulary
# so the teacher and student distributions share the same support.
teacher_model.lm_head.weight.data = teacher_model.lm_head.weight.data[:151936, :]
teacher_model.lm_head.out_features = 151936

# The student model to optimise.
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-0.5B-Instruct", torch_dtype=torch.bfloat16, device_map="auto"
)

# 90/10 train/eval split over the formatted conversations.
train_dataset = Dataset.from_dict({"messages": messages[:train_size]})
eval_dataset = Dataset.from_dict({"messages": messages[train_size:]})

training_args = GKDConfig(
    output_dir=args.output_dir,
    temperature=args.temperature,
    lmbda=args.lmbda,
    beta=args.beta,
    max_new_tokens=args.max_new_tokens,
    per_device_train_batch_size=args.per_device_train_batch_size,
    gradient_accumulation_steps=args.gradient_accumulation_steps,
    gradient_checkpointing=args.gradient_checkpointing,
    save_steps=100,
    save_total_limit=5,
)

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

trainer = GKDTrainer(
    model=model,
    teacher_model=teacher_model,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=lora_config if args.lora else None,
)
trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)
```
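
If the script is run with `--lora`, the checkpoint saved under `--output_dir` holds LoRA adapter weights rather than full model weights. A minimal, hypothetical sketch of merging those adapters back into the student model (the checkpoint path below is a placeholder):

```python
# Hypothetical post-training step: merge LoRA adapters trained with --lora
# back into the base student model. "gkd-model/checkpoint-XXX" is a placeholder.
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-0.5B-Instruct", torch_dtype=torch.bfloat16
)
merged = PeftModel.from_pretrained(base, "gkd-model/checkpoint-XXX").merge_and_unload()
merged.save_pretrained("qwq-0.5b-distilled-merged")
AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct").save_pretrained(
    "qwq-0.5b-distilled-merged"
)
```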

### Dataset:
- **Source:** `amphora/QwQ-LongCoT-130K`
- **Split:** 90% Training, 10% Evaluation

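The card lists perplexity as its evaluation metric but does not ship an evaluation script. A minimal, hypothetical sketch of estimating perplexity on the held-out 10% (assuming the same chat formatting as training) could look like this:

```python
# Hypothetical perplexity estimate on the held-out 10% split; treat as a sketch,
# not the official evaluation used for this card.
import math
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "kz919/QwQ-0.5B-Distilled"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, device_map="auto"
).eval()

data = load_dataset("amphora/QwQ-LongCoT-130K", split="train")
eval_rows = data.select(range(int(0.9 * len(data)), len(data)))  # last 10%

total_nll, total_tokens = 0.0, 0
for row in eval_rows.select(range(32)):  # small subset for a quick estimate
    text = tokenizer.apply_chat_template(
        [
            {"role": "user", "content": row["problem"]},
            {"role": "assistant", "content": row["qwq"]},
        ],
        tokenize=False,
    )
    enc = tokenizer(text, return_tensors="pt", truncation=True, max_length=4096).to(model.device)
    with torch.no_grad():
        loss = model(**enc, labels=enc["input_ids"]).loss
    n = enc["input_ids"].shape[1]
    total_nll += loss.item() * n
    total_tokens += n

print("perplexity:", math.exp(total_nll / total_tokens))
```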

---

## Example Usage:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Model name
model_name = "kz919/QwQ-0.5B-Distilled"

# Load the model
print(f"Starting to load the model {model_name} into memory")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map={"": 0}
)

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Define the prompt
prompt = "How many r in strawberry."
messages = [
    {"role": "system", "content": "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step."},
    {"role": "user", "content": prompt}
]

# Tokenize the input
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

# Generate a response
generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=4096
)
generated_ids = [
    output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]

# Decode the response
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(response)
```
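
Because the distilled model produces long step-by-step answers, it can be convenient to stream tokens as they are generated. An optional variation of the generate call above, reusing `model`, `tokenizer`, and `model_inputs`:

```python
# Optional: stream the answer token by token, reusing `model`, `tokenizer`,
# and `model_inputs` from the example above.
from transformers import TextStreamer

streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
_ = model.generate(**model_inputs, max_new_tokens=4096, streamer=streamer)
```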

---

## Applications:

1. **Conversational Assistants:**
   Suitable for AI chatbots that require reasoning and long-context understanding.

2. **Educational Tools:**
   Provides step-by-step explanations, making it ideal for learning environments.

3. **Creative Writing:**
   Assists in generating coherent, contextually aware long-form content.

4. **Technical Support:**
   Handles complex customer queries with precision and clarity.

---

## Limitations:

- Although distilled for efficiency, performance on highly complex reasoning tasks may trail the 32B teacher model.
- Best suited for conversational use that follows the Qwen family's alignment guidelines.

---

## Citation:

If you use this model in your research or applications, please cite it as:

```bibtex
@misc{qwq_0.5B_distilled,
  author    = {Kaizhao Liang},
  title     = {QwQ-0.5B-Distilled: A Reasoning Model for Edge Devices},
  year      = {2024},
  publisher = {Hugging Face},
  version   = {1.0}
}
```

---

This model is an example of how efficient fine-tuning and distillation methods can deliver robust conversational AI capabilities in a smaller, more manageable footprint.