# AthenaUI / app.py
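"""Streamlit chat UI for running Spestly's Athena models (Qwen 2.5 fine-tunes) on CPU."""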
import gc
import torch
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
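# Registry of selectable checkpoints: variant key -> display name, emoji, and
# a Hugging Face repo id per parameter size.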
MODELS = {
"athena-1": {
"name": "🧠 Athena-1",
"sizes": {
"0.5B": "Spestly/Athena-1-0.5B",
"1.5B": "Spestly/Athena-1-1.5B",
},
"emoji": "🧠",
"experimental": False,
},
"athena-2": {
"name": "πŸš€ Athena-2",
"sizes": {
"0.5B": "Spestly/Athena-2-0.5B",
"1.5B": "Spestly/Athena-2-1.5B",
},
"emoji": "πŸš€",
"experimental": False,
},
}
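# A new variant can be exposed by adding an entry of the same shape, e.g.
# (hypothetical repo id, for illustration only):
# "athena-3": {
#     "name": "🧠 Athena-3",
#     "sizes": {"0.5B": "Spestly/Athena-3-0.5B"},
#     "emoji": "🧠",
#     "experimental": True,
# },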
class AthenaInferenceApp:
    def __init__(self):
        # set_page_config must be the first Streamlit call made on the page
        st.set_page_config(
            page_title="Athena Model Inference",
            page_icon="🤖",
            layout="wide",
            menu_items={
                'Get Help': 'https://huggingface.co./collections/Spestly/athena-1-67623e58bfaadd3c2fcffb86',
                'Report a bug': 'https://huggingface.co./Spestly/Athena-1-1.5B/discussions/new',
                'About': 'Athena Model Inference Platform'
            }
        )
        if "current_model" not in st.session_state:
            st.session_state.current_model = {"tokenizer": None, "model": None, "config": None}
        if "chat_history" not in st.session_state:
            st.session_state.chat_history = []
def clear_memory(self):
"""Optimize memory management for CPU inference"""
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
def load_model(self, model_key, model_size):
try:
self.clear_memory()
if st.session_state.current_model["model"] is not None:
del st.session_state.current_model["model"]
del st.session_state.current_model["tokenizer"]
self.clear_memory()
model_path = MODELS[model_key]["sizes"][model_size]
# Load Qwen-compatible tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
model_path,
device_map="cpu", # Force CPU usage
torch_dtype=torch.float32, # Use float32 for CPU
trust_remote_code=True,
low_cpu_mem_usage=True
)
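            # Inference only: switch off dropout and other training-time behaviour
            model.eval()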
# Update session state
st.session_state.current_model.update({
"tokenizer": tokenizer,
"model": model,
"config": {
"name": f"{MODELS[model_key]['name']} {model_size}",
"path": model_path,
}
})
return f"βœ… {MODELS[model_key]['name']} {model_size} loaded successfully!"
except Exception as e:
return f"❌ Error: {str(e)}"
def respond(self, message, max_tokens, temperature, top_p, top_k):
if not st.session_state.current_model["model"]:
return "⚠️ Please select and load a model first"
        try:
            tokenizer = st.session_state.current_model["tokenizer"]
            model = st.session_state.current_model["model"]
            # System instruction steers the model's persona; the Alpaca-style
            # "### Instruction / ### Response" template matches the fine-tune.
            system_instruction = "You are Athena, a helpful AI assistant trained by Spestly. You are a Qwen 2.5 fine-tune."
            prompt = f"{system_instruction}\n\n### Instruction:\n{message}\n\n### Response:"
            inputs = tokenizer(
                prompt,
                return_tensors="pt",
                max_length=512,
                truncation=True,
                padding=True
            )
            with torch.no_grad():
                output = model.generate(
                    input_ids=inputs.input_ids,
                    attention_mask=inputs.attention_mask,
                    max_new_tokens=max_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    do_sample=True,
                    # Fall back to EOS if the tokenizer defines no pad token
                    pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                )
            response = tokenizer.decode(output[0], skip_special_tokens=True)
            # Keep only the text generated after the final "### Response:" marker
            return response.split("### Response:")[-1].strip()
except Exception as e:
return f"⚠️ Generation Error: {str(e)}"
finally:
self.clear_memory()
def main(self):
st.title("πŸ¦‰ AthenaUI")
with st.sidebar:
st.header("πŸ›  Model Selection")
model_key = st.selectbox(
"Choose Athena Variant",
list(MODELS.keys()),
                format_func=lambda x: f"{MODELS[x]['name']} {'🧪' if MODELS[x]['experimental'] else ''}"
)
model_size = st.selectbox(
"Choose Model Size",
list(MODELS[model_key]["sizes"].keys())
)
            if st.button("Load Model"):
                with st.spinner("Loading model... This may take a few minutes."):
                    status = self.load_model(model_key, model_size)
                # load_model reports failures as strings, so surface them as errors
                if status.startswith("✅"):
                    st.success(status)
                else:
                    st.error(status)
st.header("πŸ”§ Generation Parameters")
max_tokens = st.slider("Max New Tokens", min_value=10, max_value=512, value=256, step=10)
temperature = st.slider("Temperature", min_value=0.1, max_value=2.0, value=0.4, step=0.1)
top_p = st.slider("Top-P", min_value=0.1, max_value=1.0, value=0.9, step=0.1)
top_k = st.slider("Top-K", min_value=1, max_value=100, value=50, step=1)
if st.button("Clear Chat History"):
st.session_state.chat_history = []
st.rerun()
st.markdown("*πŸ’¬ Bored of Athena? Try Atlas-Flash and Atlas-Pro!*")
for message in st.session_state.chat_history:
with st.chat_message(message["role"]):
st.markdown(message["content"])
if prompt := st.chat_input("Message Athena..."):
st.session_state.chat_history.append({"role": "user", "content": prompt})
with st.chat_message("user"):
st.markdown(prompt)
with st.chat_message("assistant"):
with st.spinner("Generating response..."):
response = self.respond(prompt, max_tokens, temperature, top_p, top_k)
st.markdown(response)
st.session_state.chat_history.append({"role": "assistant", "content": response})
def run():
try:
app = AthenaInferenceApp()
app.main()
except Exception as e:
st.error(f"⚠️ Application Error: {str(e)}")
if __name__ == "__main__":
run()
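# Launch locally with: streamlit run app.py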