---
license: apache-2.0
datasets:
- amphora/QwQ-LongCoT-130K
language:
- en
metrics:
- perplexity
base_model:
- Qwen/Qwen2.5-0.5B-Instruct
---
## Model Details:

- **Base Model:** Qwen/Qwen2-0.5B-Instruct
- **Teacher Model:** Qwen/QwQ-32B-Preview
- **Distillation Framework:** Instruction Tuning
- **Task Type:** Conversational AI / Causal Language Modeling
- **Parameters:** 0.5B
- **Special Features:**
  - Integrated gradient checkpointing for efficient training
  - Step-by-step reasoning capabilities for better problem-solving
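
The "0.5B" figure can be sanity-checked by counting the base model's parameters directly. This is a minimal sketch (not part of the model card's own code), assuming only `torch` and `transformers` are installed:

```python
import torch
from transformers import AutoModelForCausalLM

# Load the base model and count its parameters to confirm the ~0.5B size.
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-0.5B-Instruct", torch_dtype=torch.bfloat16
)
num_params = sum(p.numel() for p in model.parameters())
print(f"{num_params / 1e9:.2f}B parameters")
```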

---

## Training:

QwQ-0.5B-Distilled was trained on the **QwQ-LongCoT-130K dataset**, a curated collection of long chain-of-thought examples designed for reasoning and conversational AI tasks. Distillation here is plain instruction tuning: the student is fine-tuned to reproduce the teacher model's responses, aligning its predictions with high-quality, step-by-step outputs.

### Training Progress:
[▓▓▓▓▓▓▓▓▓▓] 100%

### Training Script:

```python
import argparse

import torch
from datasets import Dataset, load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM

parser = argparse.ArgumentParser()
parser.add_argument("--max_length", type=int, default=4096)
parser.add_argument("--output_dir", type=str, default="gkd-model")
parser.add_argument("--per_device_train_batch_size", type=int, default=1)
parser.add_argument("--gradient_accumulation_steps", type=int, default=16)
parser.add_argument("--gradient_checkpointing", action="store_true", default=False)
parser.add_argument("--resume_from_checkpoint", action="store_true", default=False)
parser.add_argument("--lora", action="store_true")
args = parser.parse_args()

# Turn each dataset row into a chat-formatted example: the teacher's
# long chain-of-thought answer (the "qwq" column) is the target response.
qwq_dataset = load_dataset("amphora/QwQ-LongCoT-130K", split="train")
messages = []
for each in qwq_dataset:
    msg = [
        {"role": "system", "content": "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step."},
        {"role": "user", "content": each["problem"]},
        {"role": "assistant", "content": each["qwq"]},
    ]
    messages.append(msg)

# 90/10 train/eval split.
TRAIN_SPLIT_RATIO = 0.9
train_size = int(TRAIN_SPLIT_RATIO * len(messages))

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

# The student model to optimize.
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-0.5B-Instruct", torch_dtype=torch.bfloat16, device_map="auto"
)

train_dataset = Dataset.from_dict({"messages": messages[:train_size]})
eval_dataset = Dataset.from_dict({"messages": messages[train_size:]})

training_args = SFTConfig(
    output_dir=args.output_dir,
    max_seq_length=args.max_length,
    per_device_train_batch_size=args.per_device_train_batch_size,
    gradient_accumulation_steps=args.gradient_accumulation_steps,
    gradient_checkpointing=args.gradient_checkpointing,
    save_steps=100,
    save_total_limit=5,
)

# Optional LoRA adapter configuration, applied only when --lora is passed.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

# Compute the loss only on the assistant's response, not on the prompt.
response_template = "<|im_start|>assistant\n"
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)

trainer = SFTTrainer(
    model=model,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=lora_config if args.lora else None,
    data_collator=collator,
)
trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)
```
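
The `DataCollatorForCompletionOnlyLM` used above is what makes this response-only supervision: every token before the `<|im_start|>assistant\n` marker receives a label of `-100`, so the loss is computed only on the teacher's answer. A minimal sketch of that behavior (illustrative, not part of the training script):

```python
from transformers import AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
collator = DataCollatorForCompletionOnlyLM("<|im_start|>assistant\n", tokenizer=tokenizer)

text = tokenizer.apply_chat_template(
    [
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "Hello!"},
    ],
    tokenize=False,
)
batch = collator([tokenizer(text)])
# Prompt positions are -100 (ignored by the loss); only the reply is supervised.
print(batch["labels"][0])
```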

### Dataset:
- **Source:** `amphora/QwQ-LongCoT-130K`
- **Split:** 90% Training, 10% Evaluation
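
Each row pairs a `problem` (the question) with the teacher's long-form `qwq` answer, which the training script above maps onto user and assistant turns. A quick way to inspect the split sizes and a sample row (a sketch; exact row counts may differ):

```python
from datasets import load_dataset

ds = load_dataset("amphora/QwQ-LongCoT-130K", split="train")
train_size = int(0.9 * len(ds))
print(f"train: {train_size} rows, eval: {len(ds) - train_size} rows")

# Peek at one example: the question and the teacher's chain-of-thought answer.
print(ds[0]["problem"][:200])
print(ds[0]["qwq"][:200])
```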

---

## Example Usage:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Model name
model_name = "kz919/QwQ-0.5B-Distilled-SFT"

# Load the model
print(f"Starting to load the model {model_name} into memory")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map={"": 0},
)

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Define the prompt
prompt = "How many r in strawberry."
messages = [
    {"role": "system", "content": "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step."},
    {"role": "user", "content": prompt},
]

# Apply the chat template and tokenize the input
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

# Generate a response
generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=4096,
)

# Keep only the newly generated tokens, dropping the prompt
generated_ids = [
    output_ids[len(input_ids):]
    for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]

# Decode the response
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(response)
```
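
Because the model is tuned to think step-by-step, responses can run long, and streaming tokens as they are generated is often more convenient than waiting for the full output. A minimal variation (a sketch reusing `model`, `tokenizer`, and `model_inputs` from the snippet above):

```python
from transformers import TextStreamer

# Print tokens to stdout as they are generated.
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
_ = model.generate(**model_inputs, max_new_tokens=4096, streamer=streamer)
```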

---

## Applications:

1. **Conversational Assistants:**
   Suitable for AI chatbots that require reasoning and long-context understanding.

2. **Educational Tools:**
   Provides step-by-step explanations, making it ideal for learning environments.

3. **Creative Writing:**
   Assists in generating coherent, contextually aware long-form content.

4. **Technical Support:**
   Handles complex customer queries with precision and clarity.

---

## Limitations:

- While distilled for efficiency, performance on highly complex reasoning tasks may slightly trail the teacher model.
- Warning 🚨🚨🚨: this model is not fully trained, merely a proof of concept. Don't yell at me if it's outputting nonsense.

---

## Citation:

If you use this model in your research or applications, please cite it as:

```bibtex
@misc{qwq_0.5B_distilled,
  author    = {Kaizhao Liang},
  title     = {QwQ-0.5B-Distilled: A Reasoning Model for Edge Devices},
  year      = {2024},
  publisher = {Hugging Face},
  version   = {1.0}
}
```

---

This model is an example of how efficient fine-tuning and distillation methods can deliver robust conversational AI capabilities in a smaller, more manageable footprint.