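"""Streamlit front-end for running Spestly's Athena models locally.

Lets the user pick an Athena variant and parameter size, load it on CPU via
Hugging Face Transformers, and chat with it through a simple web UI.
"""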
import gc

import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
|
# Registry of available Athena checkpoints: each variant maps its display name,
# the Hugging Face repo ID for every parameter size, and whether it should be
# flagged as experimental (🧪) in the sidebar.
MODELS = {
    "athena-1": {
        "name": "🧠 Athena-1",
        "sizes": {
            "0.5B": "Spestly/Athena-1-0.5B",
            "1.5B": "Spestly/Athena-1-1.5B",
        },
        "emoji": "🧠",
        "experimental": False,
    },
    "athena-2": {
        "name": "🚀 Athena-2",
        "sizes": {
            "0.5B": "Spestly/Athena-2-0.5B",
            "1.5B": "Spestly/Athena-2-1.5B",
        },
        "emoji": "🚀",
        "experimental": False,
    },
}
|
|
class AthenaInferenceApp:
    """Streamlit chat UI for running Athena models on CPU."""

    def __init__(self):
        # Page config must be set before any other Streamlit call renders output.
        st.set_page_config(
            page_title="Athena Model Inference",
            page_icon="🤖",
            layout="wide",
            menu_items={
                'Get Help': 'https://huggingface.co./collections/Spestly/athena-1-67623e58bfaadd3c2fcffb86',
                'Report a bug': 'https://huggingface.co./Spestly/Athena-1-1.5B/discussions/new',
                'About': 'Athena Model Inference Platform'
            }
        )

        # Persist the loaded model and the conversation across Streamlit reruns.
        if "current_model" not in st.session_state:
            st.session_state.current_model = {"tokenizer": None, "model": None, "config": None}
        if "chat_history" not in st.session_state:
            st.session_state.chat_history = []

|
    def clear_memory(self):
        """Release cached GPU memory (if any) and force a garbage-collection pass."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

|
    def load_model(self, model_key, model_size):
        """Load the selected Athena checkpoint, unloading any previously loaded model first."""
        try:
            self.clear_memory()

            # Drop references to the previous model so its memory can be reclaimed.
            if st.session_state.current_model["model"] is not None:
                st.session_state.current_model["model"] = None
                st.session_state.current_model["tokenizer"] = None
                self.clear_memory()

            model_path = MODELS[model_key]["sizes"][model_size]

            tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
            # Some tokenizers ship without a pad token; fall back to EOS so padding works.
            if tokenizer.pad_token_id is None:
                tokenizer.pad_token = tokenizer.eos_token

            model = AutoModelForCausalLM.from_pretrained(
                model_path,
                device_map="cpu",
                torch_dtype=torch.float32,
                trust_remote_code=True,
                low_cpu_mem_usage=True
            )
            model.eval()

            st.session_state.current_model.update({
                "tokenizer": tokenizer,
                "model": model,
                "config": {
                    "name": f"{MODELS[model_key]['name']} {model_size}",
                    "path": model_path,
                }
            })
            return f"✅ {MODELS[model_key]['name']} {model_size} loaded successfully!"
        except Exception as e:
            return f"❌ Error: {str(e)}"

|
    def respond(self, message, max_tokens, temperature, top_p, top_k):
        """Generate a single assistant reply for the given user message."""
        if not st.session_state.current_model["model"]:
            return "⚠️ Please select and load a model first"

        try:
            # Build an Alpaca-style prompt around the user's message.
            system_instruction = "You are Athena, a helpful AI assistant trained by Spestly. You are a Qwen 2.5 fine-tune."
            prompt = f"{system_instruction}\n\n### Instruction:\n{message}\n\n### Response:"

            inputs = st.session_state.current_model["tokenizer"](
                prompt,
                return_tensors="pt",
                max_length=512,
                truncation=True,
                padding=True
            )

            with torch.no_grad():
                output = st.session_state.current_model["model"].generate(
                    input_ids=inputs.input_ids,
                    attention_mask=inputs.attention_mask,
                    max_new_tokens=max_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    do_sample=True,
                    pad_token_id=st.session_state.current_model["tokenizer"].pad_token_id,
                    eos_token_id=st.session_state.current_model["tokenizer"].eos_token_id,
                )

            # The decoded text echoes the prompt; keep only what follows "### Response:".
            response = st.session_state.current_model["tokenizer"].decode(output[0], skip_special_tokens=True)
            return response.split("### Response:")[-1].strip()
        except Exception as e:
            return f"⚠️ Generation Error: {str(e)}"
        finally:
            self.clear_memory()

|
    def main(self):
        """Render the sidebar controls and the chat interface."""
        st.title("🦉 AthenaUI")

        with st.sidebar:
            st.header("🔍 Model Selection")

            model_key = st.selectbox(
                "Choose Athena Variant",
                list(MODELS.keys()),
                format_func=lambda x: f"{MODELS[x]['name']} {'🧪' if MODELS[x]['experimental'] else ''}"
            )

            model_size = st.selectbox(
                "Choose Model Size",
                list(MODELS[model_key]["sizes"].keys())
            )

            if st.button("Load Model"):
                with st.spinner("Loading model... This may take a few minutes."):
                    status = self.load_model(model_key, model_size)
                if status.startswith("✅"):
                    st.success(status)
                else:
                    st.error(status)

            st.header("🔧 Generation Parameters")
            max_tokens = st.slider("Max New Tokens", min_value=10, max_value=512, value=256, step=10)
            temperature = st.slider("Temperature", min_value=0.1, max_value=2.0, value=0.4, step=0.1)
            top_p = st.slider("Top-P", min_value=0.1, max_value=1.0, value=0.9, step=0.1)
            top_k = st.slider("Top-K", min_value=1, max_value=100, value=50, step=1)

            if st.button("Clear Chat History"):
                st.session_state.chat_history = []
                st.rerun()

            st.markdown("*💬 Bored of Athena? Try Atlas-Flash and Atlas-Pro!*")

        # Replay the stored conversation, then handle any new user input.
        for message in st.session_state.chat_history:
            with st.chat_message(message["role"]):
                st.markdown(message["content"])

        if prompt := st.chat_input("Message Athena..."):
            st.session_state.chat_history.append({"role": "user", "content": prompt})
            with st.chat_message("user"):
                st.markdown(prompt)

            with st.chat_message("assistant"):
                with st.spinner("Generating response..."):
                    response = self.respond(prompt, max_tokens, temperature, top_p, top_k)
                    st.markdown(response)

            st.session_state.chat_history.append({"role": "assistant", "content": response})

|
|
def run():
    try:
        app = AthenaInferenceApp()
        app.main()
    except Exception as e:
        st.error(f"⚠️ Application Error: {str(e)}")


if __name__ == "__main__":
    run()
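# Launch with Streamlit, e.g. (assuming this file is saved as app.py):
#   streamlit run app.py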