Bitnet-Nous-Llama3-225M 🚀

Este modelo es una variante optimizada del Llama3 utilizando la arquitectura BitNet, lo que reduce los pesos a los valores -1, 0, y 1 para mejorar la eficiencia en el cómputo sin perder precisión.

image/png

Modelo Base 🦙

Arquitectura 🔧

El modelo transforma las capas lineales de Llama3 en capas BitLinear, aprovechando las siguientes técnicas de cuantización:

  • Cuantización de activaciones: Escala a ±127
  • Cuantización de pesos: Escala a ±1

Especificaciones Técnicas 📋

  • Dimensiones: 768
  • Capas: 6
  • Contexto: 256 tokens
  • Tamaño intermedio: 1024
  • Número de cabezas de atención: 6

Dataset 📚

El modelo fue entrenado usando el dataset Cosmopedia-100k-pretrain, que contiene una variedad de datos de texto.

Entrenamiento ⚙️

El modelo fue entrenado con la siguiente configuración:

  • Lote: 16
  • Tasa de aprendizaje: 1.5e-4
  • Épocas: 2
  • Acumulación de gradientes: 2 pasos
  • Decaimiento de pesos: 0.01
  • Precisión Mixta: FP16

Monitoreo 📊

El proceso de entrenamiento fue monitoreado usando Weights & Biases.

Uso del Modelo 💻

Para usar este modelo, puedes cargarlo desde Hugging Face con el siguiente código:

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.llama.modeling_llama import *
import torch
from torch import nn
import torch.nn.functional as F
import coloredlogs
import logging

from utils.utils import count_parameters

coloredlogs.install(level='INFO', fmt='%(asctime)s - %(levelname)s - %(message)s', logger=logging.getLogger())
logger = logging.getLogger(__name__)




HF_TOKEN = "tuclaveaqui"
#model = "ejbejaranos/Bitnet-Llama3-from8BM-now2B"
model = "ejbejaranos/Bitnet-Nous-Llama3-225M" ## Working

# Load a pretrained BitNet model
tokenizer = AutoTokenizer.from_pretrained(model)

model = AutoModelForCausalLM.from_pretrained(
    model,
    token=HF_TOKEN
)


def count_parameters(model):
    # Calculate the number of parameters in billions
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) / 10**9
    print(f"Model size: {num_params:.3f}B parameters")
    return int(num_params)



# Establece el pad_token_id
model.config.pad_token_id = tokenizer.eos_token_id

def activation_quant(x):
    scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
    y = (x * scale).round().clamp_(-128, 127)
    y = y / scale
    return y

def weight_quant(w):
    scale = 1.0 / w.abs().mean().clamp_(min=1e-5)
    u = (w * scale).round().clamp_(-1, 1)
    u = u / scale
    return u

class BitLinear(nn.Linear):
    def forward(self, x):
        w = self.weight  # a weight tensor with shape [d, k]
        x = x.to(w.device)
        RMSNorm = LlamaRMSNorm(x.shape[-1]).to(w.device)
        x_norm = RMSNorm(x)
        x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach()
        w_quant = w + (weight_quant(w) - w).detach()
        y = F.linear(x_quant, w_quant)
        return y

def convert_to_bitnet(model, copy_weights):
    for name, module in model.named_modules():
        if isinstance(module, LlamaSdpaAttention) or isinstance(module, LlamaMLP):
            for child_name, child_module in module.named_children():
                if isinstance(child_module, nn.Linear):
                    bitlinear = BitLinear(child_module.in_features, child_module.out_features, child_module.bias is not None).to(device="cuda:0")
                    if copy_weights:
                        bitlinear.weight = child_module.weight
                        if child_module.bias is not None:
                            bitlinear.bias = child_module.bias
                    setattr(module, child_name, bitlinear)
        elif isinstance(module, LlamaDecoderLayer):
            for child_name, child_module in module.named_children():
                if isinstance(child_module, LlamaRMSNorm) and child_name == "input_layernorm":
                    setattr(module, child_name, nn.Identity().to(device="cuda:0"))

convert_to_bitnet(model, copy_weights=True)
model.to(device="cuda:0")


logger.info(f"🔢 Number of parameters in the model after extracting weights: {count_parameters(model)}")
logger.info(f"📏 Reduced model structure:\n{model}")





prompt = "What is Machine Learning?"
inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(model.device)
inputs['attention_mask'] = inputs['input_ids'] != model.config.pad_token_id

generate_ids = model.generate(inputs.input_ids, attention_mask=inputs['attention_mask'], max_length=250)
decoded_output = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)

print(decoded_output[0])  # Print the generated response
Downloads last month
3
Safetensors
Model size
225M params
Tensor type
F32
·
Inference API
Unable to determine this model's library. Check the docs .

Dataset used to train ITCL/Llama3-8B-Bitnet-now-225M