import os

import torch
from modeling_jamba import JambaForCausalLM

output_dir = "/home/user/jamba-small"

# Load the full 32-layer Jamba-v0.1 on CPU; bfloat16 halves the memory
# footprint relative to fp32.
model = JambaForCausalLM.from_pretrained("ai21labs/Jamba-v0.1", device_map="cpu", torch_dtype=torch.bfloat16)
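
# Optional: inspect which original layers are attention vs. Mamba and MoE vs.
# dense before deciding what to keep. A sketch that assumes the config exposes
# the attn_layer_* / expert_layer_* fields from Jamba-v0.1's config.json
# (attn_layer_period=8, attn_layer_offset=4, expert_layer_period=2,
# expert_layer_offset=1):
cfg = model.config
for i in range(cfg.num_hidden_layers):
    kind = "attention" if i % cfg.attn_layer_period == cfg.attn_layer_offset else "mamba"
    experts = "moe" if i % cfg.expert_layer_period == cfg.expert_layer_offset else "dense"
    print(f"layer {i}: {kind}/{experts}")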


def prune_and_copy_additional_layers(original_state_dict):
    # Map new layer index -> original layer index: keep the first six layers
    # and graft on the last two. Each source index then lands on a position
    # with the same block type (attention vs. Mamba, MoE vs. dense) under
    # Jamba-v0.1's period-8 attention / period-2 expert layout.
    layer_mapping = {
        0: 0,
        1: 1,
        2: 2,
        3: 3,
        4: 4,
        5: 5,
        6: 30,
        7: 31,
    }

    new_state_dict = {}

    for new_idx, orig_idx in layer_mapping.items():
        # The trailing dot matters: a bare "model.layers.3" prefix would also
        # match "model.layers.30" and "model.layers.31".
        prefix = f"model.layers.{orig_idx}."
        for key, value in original_state_dict.items():
            if key.startswith(prefix):
                new_key = key.replace(f"layers.{orig_idx}.", f"layers.{new_idx}.")
                new_state_dict[new_key] = value

    # Embedding, final norm, and LM head are shared regardless of depth.
    global_keys = ['model.embed_tokens.weight', 'model.final_layernorm.weight', 'lm_head.weight']
    for key in global_keys:
        new_state_dict[key] = original_state_dict[key]

    return new_state_dict


pruned_state_dict = prune_and_copy_additional_layers(model.state_dict())
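
# Quick sanity check (a sketch, not part of the original flow): the pruned
# dict should reference exactly layers 0-7 plus the shared tensors.
layer_ids = {key.split(".")[2] for key in pruned_state_dict if key.startswith("model.layers.")}
assert layer_ids == {str(i) for i in range(8)}, layer_ids
print(f"{len(pruned_state_dict)} tensors kept")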

print("Saving weights...")
os.makedirs(output_dir, exist_ok=True)
# torch.save needs a file path, not a directory; write the standard
# pytorch_model.bin inside output_dir so from_pretrained can find it.
torch.save(pruned_state_dict, os.path.join(output_dir, "pytorch_model.bin"))
print("Done!")