In [5]:
model_name_or_path = "mistralai/Mixtral-8x7B-Instruct-v0.1"#@param {type:"string"}
experts_extract_bit = "10101010" #@param {type:"string"}
num_experts_per_tok = 2 #@param {type:"integer"}

temp_dir = "/content/drive/MyDrive/tf_models" #@param {type:"string"}
model_name = model_name_or_path.split("/")[-1]
target_dir = f"{temp_dir}/{model_name}"
save_dir   =  "/content/drive/MyDrive/tf_models/mx4x7b_x3" #@param {type:"string"}


experts_indexies = [i for i, bit in enumerate(experts_extract_bit) if bit == '1']
# print( experts_indexies )

if len(experts_extract_bit) != 8:
    raise ValueError("experts_extract_bit length must be 8")


In [None]:
!pip install git+https://github.com/huggingface/transformers --upgrade
!pip install torch accelerate bitsandbytes flash_attn sentencepiece protobuf

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
%cd {temp_dir}
save_model_dir = model_name.split('/')[-1]
!mkdir -p {save_model_dir}

!wget https://huggingface.co./{model_name_or_path}/resolve/main/config.json -O {save_model_dir}/config.json
!wget https://huggingface.co./{model_name_or_path}/resolve/main/model.safetensors.index.json -O {save_model_dir}/model.safetensors.index.json
!wget https://huggingface.co./{model_name_or_path}/resolve/main/generation_config.json -O {save_model_dir}/generation_config.json

for i in range(1,20):
    file_count_str = str(i).zfill(5)
    !wget https://huggingface.co./{model_name_or_path}/resolve/main/model-{file_count_str}-of-00019.safetensors?download=true -O {save_model_dir}/model-{file_count_str}-of-00019.safetensors

In [5]:
def download_tokenizer_model(save_tokenizer_dir):
    !wget https://huggingface.co./{model_name_or_path}/resolve/main/tokenizer.model -O {save_tokenizer_dir}/tokenizer.model


In [None]:
%cd {temp_dir}

import json
import re
import torch
from safetensors import safe_open
from safetensors.torch import save_file

# model-00001-of-00019.safetensors
# model.safetensors.index.json

# save tokenizer
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
tokenizer.save_pretrained(save_dir)

# save config
config_path = f"{target_dir}/config.json"
config = None
with open(config_path, "r") as f:
    config = json.load(f)
    config["num_experts_per_tok"] = num_experts_per_tok if len(experts_indexies) >= num_experts_per_tok else 1
    config["num_local_experts"] = len(experts_indexies)

# save config
with open(f"{save_dir}/config.json", "w") as f:
    json.dump(config, f, indent=2)


# weight
weight_map = {}
first_weights = ["lm_head.weight", "model.embed_tokens.weight", "model.norm.weight"]

# load weight map
bin_index_path = f"{target_dir}/model.safetensors.index.json"
with open(bin_index_path, "r") as f:
    weight_map = json.load(f)["weight_map"]

def tensor_load(file_name, map_location=None):
    tensors = {}
    with safe_open(file_name, framework="pt") as f:
        for k in f.keys():
            tensors[k] = f.get_tensor(k)
    return tensors

def get_weight_byte_size(weight):

    if isinstance(weight, torch.Tensor):
        weight_byte_size = weight.nelement() * weight.element_size()
    else:
        weight_byte_size = sum(p.nelement() * p.element_size() for p in weight.parameters())

    return weight_byte_size

# load weight map
layers = {}
for key in weight_map.keys():
    if key in first_weights:
        continue

    # keyが"model.layers.[0-9]+."にmatchする場合はlayers_listに追加する
    layer_str = re.match(r"model\.layers\.[0-9]+\.", key)[0]
    if layer_str:
        layer_no = re.findall(r"\d+",layer_str)
        layer_no = layer_no[0]
        if layer_no not in layers.keys():
            layers[layer_no] = []

        layers[layer_no].append({ "key":key, "file_name":weight_map[key] })

# new weight_map index
new_weight_map = {
  "metadata": {
    "total_size": 0
  },
  "weight_map": {
  }
}

# load tensors
total_size = 0
tensor_weights = {}
tensors = {}
current_file_name = ""

file_count = 0
file_count_str = str(file_count).zfill(5)

for key in first_weights:
    file_name = weight_map[key]
    if current_file_name != file_name:

        # load safetensor
        tensors = tensor_load(f"{target_dir}/{file_name}", map_location="cpu")
        current_file_name = file_name

    tensor_weights[key] = tensors[key]
    new_weight_map["weight_map"][key] = f"model-{file_count_str}.safetensors"

    # add weight size
    total_size += get_weight_byte_size(tensor_weights[key])

# save tensor
save_file(tensor_weights, f"{save_dir}/model-{file_count_str}.safetensors", metadata={"format":"pt"})
file_count += 1

layer_keys = sorted([ int(k) for k in layers.keys()])

for layer_no in layer_keys:
    print("starting layer:",layer_no)
    file_count_str = str(file_count).zfill(5)
    tensor_weights = {}

    stock_expert_weights = {}

    current_file_name = ""
    for info in layers[str(layer_no)]:
        file_name = info["file_name"]
        if current_file_name != file_name:
            print("Loading Tensors ", file_name)
            tensors = tensor_load(f"{target_dir}/{file_name}", map_location="cpu")
            current_file_name = file_name

        layer_key = info["key"]
        layer_weights = tensors[layer_key]

        if 'experts' in layer_key:

            lk = re.findall(r"block_sparse_moe[.]experts[.][0-9]+.w", layer_key)[0]
            exp_index = int( re.findall(r"\d+",lk)[0] )

            # select target experts
            if exp_index in experts_indexies:
                new_layer_key = re.sub(r"block_sparse_moe\.experts\.\d+\.w", f"block_sparse_moe.experts.{experts_indexies.index(exp_index)}.w", layer_key)

                tensor_weights[new_layer_key] = layer_weights

                # add weight size
                total_size += get_weight_byte_size(tensor_weights[new_layer_key])

                new_weight_map["weight_map"][new_layer_key] = f"model-{file_count_str}.safetensors"
                print("new experts", new_layer_key, tensor_weights[new_layer_key].shape, "from", layer_key)

        elif 'gate' in layer_key:
            print("slice gate ", experts_indexies, layer_weights.shape, f"-> ({len(experts_indexies)}, 4096)", layer_key)

            # slice gate
            tensor_weights[layer_key] = layer_weights[experts_indexies]

            # add weight size
            total_size += get_weight_byte_size(tensor_weights[layer_key])

            new_weight_map["weight_map"][layer_key] = f"model-{file_count_str}.safetensors"
            print(layer_key, tensor_weights[layer_key].shape)

        else:
            tensor_weights[layer_key] = layer_weights

            # add weight size
            total_size += get_weight_byte_size(tensor_weights[layer_key])

            new_weight_map["weight_map"][layer_key] = f"model-{file_count_str}.safetensors"
            print(layer_key, tensor_weights[layer_key].shape)

    # save tensor
    save_file(tensor_weights, f"{save_dir}/model-{file_count_str}.safetensors", metadata={"format":"pt"})
    print("Save Tensors ", f"{save_dir}/model-{file_count_str}.safetensors")
    file_count += 1

# save new_weight_map
new_weight_map["metadata"]["total_size"] = total_size
with open(f"{save_dir}/model.safetensors.index.json", "w") as f:
    json.dump(new_weight_map, f, indent=2)

# download tokenizer.model
download_tokenizer_model(save_dir)

print("Done.")


In [4]:
from transformers import AutoTokenizer, AutoModelForCausalLM, MixtralForCausalLM
import torch

model_name_or_path = save_dir

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
model = MixtralForCausalLM.from_pretrained(model_name_or_path, load_in_8bit=True)

text = "[INST] What was John Holt's vision on education? [/INST] "
# text = "[INST] What is the best anime? [/INST] "
inputs = tokenizer("<s> " + text, return_tensors="pt")

outputs = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))


Loading checkpoint shards:   0%|          | 0/33 [00:00<?, ?it/s]

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


 [INST] What was John Holt's vision on education? [/INST] 10 John Holt's vision on education was to create a system that would allow students to learn at their own pace. He believed that this would help students to become better learners. He also wanted to provide a platform that would allow students to learn in a safe environment. He believed that this would help students to become better learners. He also wanted to provide a platform that would allow students to learn in a safe environment. He believed that this would help students to become better learners. He also wanted to provide a platform that would allow students to learn in a safe environment. He believed that this would help students to become
