---
license: apache-2.0
language:
- en
inference: false
---

# Model Card for TinyMixtral-x8-Clonebase-7b

This model is based on [TinyLlama-1.1B](https://huggingface.co/TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T). It was converted to a Mistral model, and clones of its FFN were then placed into a Mixtral model as experts.

**This model was created experimentally as a starting point for training a small Mixtral.**

**Without further training, the performance of this model is the same as TinyLlama.**

# How it was made

TinyLlama is a Llama model, so I first converted it to a Mistral model. I then cloned the FFN part and used the clones as the experts. Because all experts hold the same tensors, the performance does not change. All gates have the same value.

# How To Convert

Use a Colab CPU high-memory runtime. This model was created with experts=8, but since the experts are clones, you can create as many as you like.

[tinyllama_to_mixtral_clonebase.ipynb](https://huggingface.co/mmnga/TinyMixtral-x8-Clonebase-7b/blob/main/notebook/tinyllama_to_mixtral_clonebase.ipynb)

# revision

[main TinyLlama-1.1B-intermediate-step-1431k-3T](https://huggingface.co/mmnga/TinyMixtral-x8-Clonebase-7b)

[old TinyLlama-1.1B-intermediate-step-1195k-token-2.5T](https://huggingface.co/mmnga/TinyMixtral-x8-Clonebase-7b/tree/TinyLlama-1.1B-intermediate-step-1195k-token-2.5T)

# Usage

~~~bash
pip install transformers --upgrade
pip install flash_attn bitsandbytes accelerate
~~~

~~~python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

model_name_or_path = "mmnga/TinyMixtral-x8-Clonebase-7b"

# Load the tokenizer and the model (8-bit weights via bitsandbytes).
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, device_map="auto", load_in_8bit=True)

prompt = "Introducing the recipe for today's dinner."

# Generation needs no gradients.
with torch.no_grad():
    token_ids = tokenizer.encode(prompt, return_tensors="pt")
    output_ids = model.generate(
        token_ids.to(model.device),
        do_sample=True,
        max_new_tokens=128,
        repetition_penalty=1.5
    )

# Decode the first (and only) generated sequence back to text.
output = tokenizer.decode(output_ids[0])
print(output)
~~~
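
The statement in "How it was made" that cloned experts leave the performance unchanged can be checked on a toy example. The following is a minimal pure-Python sketch, not the conversion code from the notebook. It assumes Mixtral-style routing (softmax over gate logits, top-2 selection, renormalized weights) and a SwiGLU FFN; the names `ffn`, `moe`, `dense`, and the tiny dimensions are all illustrative.

~~~python
import math
import random


def matvec(W, x):
    """Multiply matrix W (list of rows) by vector x."""
    return [sum(w * v for w, v in zip(row, x)) for row in W]


def silu(v):
    return v / (1.0 + math.exp(-v))


def ffn(x, w1, w2, w3):
    """SwiGLU feed-forward: w2 @ (silu(w1 @ x) * (w3 @ x))."""
    hidden = [silu(a) * b for a, b in zip(matvec(w1, x), matvec(w3, x))]
    return matvec(w2, hidden)


def softmax(logits):
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]


def moe(x, experts, gate, top_k=2):
    """Route x to the top_k experts and mix their outputs with renormalized weights."""
    probs = softmax(matvec(gate, x))
    top = sorted(range(len(experts)), key=lambda i: probs[i], reverse=True)[:top_k]
    norm = sum(probs[i] for i in top)
    out = [0.0] * len(x)
    for i in top:
        y = ffn(x, *experts[i])
        out = [o + (probs[i] / norm) * v for o, v in zip(out, y)]
    return out


def rand_matrix(rows, cols, rng):
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
dim, hidden_dim, num_experts = 4, 8, 8

# One dense FFN (w1, w2, w3), standing in for the original TinyLlama FFN.
dense = (
    rand_matrix(hidden_dim, dim, rng),
    rand_matrix(dim, hidden_dim, rng),
    rand_matrix(hidden_dim, dim, rng),
)

# Every expert is a clone of the dense FFN, and every gate row has the same values.
experts = [dense] * num_experts
gate_row = [rng.uniform(-1.0, 1.0) for _ in range(dim)]
gate = [gate_row] * num_experts

x = [rng.uniform(-1.0, 1.0) for _ in range(dim)]
dense_out = ffn(x, *dense)
moe_out = moe(x, experts, gate)

# Largest element-wise difference between the dense and the MoE output.
max_diff = max(abs(a - b) for a, b in zip(dense_out, moe_out))
print(max_diff)
~~~

Because the selected experts compute the same function and their renormalized weights sum to 1, the mixture collapses to the dense FFN output regardless of which experts the router picks. The same argument holds for any number of experts, which is consistent with the note that the expert count can be chosen freely.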