import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
from dotenv import load_dotenv
from functools import lru_cache

# Load environment variables
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
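
# Minimal sketch of the .env file this script expects next to it; the variable
# name HF_TOKEN comes from the code above, the value below is a placeholder:
#   HF_TOKEN=hf_xxxxxxxxxxxxxxxxxxxxxxxx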

# App title and description
st.title("I am Your GrowBuddy 🌱")
st.write("Let me help you start gardening. Let's grow together!")

# Load the tokenizer and model once and cache them across Streamlit reruns
@st.cache_resource
def load_model():
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            "TheSheBots/UrbanGardening", token=HF_TOKEN, use_fast=True
        )
        # Load the model in float32 for CPU execution
        model = AutoModelForCausalLM.from_pretrained(
            "google/gemma-2b-it", token=HF_TOKEN, torch_dtype=torch.float32
        )
        return tokenizer, model
    except Exception as e:
        st.error(f"Failed to load model: {e}")
        return None, None

# Load model and tokenizer (cached by st.cache_resource)
tokenizer, model = load_model()
if tokenizer is None or model is None:
    st.stop()

# Run inference on the CPU; the model was already loaded in float32, which is
# the dtype best supported for CPU inference
device = torch.device("cpu")
model = model.to(device)

# Initialize session state messages
if "messages" not in st.session_state:
    st.session_state.messages = [
        {"role": "assistant", "content": "Hello there! How can I help you with gardening today?"}
    ]

# Display conversation history
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.write(message["content"])

# LRU-cache generated answers so repeated questions skip redundant computation;
# the function takes only the prompt string so its arguments are hashable
@lru_cache(maxsize=128)
def cached_generate_response(prompt):
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, padding=True, max_length=512).to(device)
    outputs = model.generate(**inputs, max_new_tokens=50, temperature=0.7, do_sample=True)
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response

# Generate a response, reusing cached results for repeated prompts
def generate_response(prompt):
    try:
        return cached_generate_response(prompt)
    except Exception as e:
        st.error(f"Error during text generation: {e}")
        return "Sorry, I couldn't process your request."

# User input field for gardening questions
user_input = st.chat_input("Type your gardening question here:")
if user_input:
    with st.chat_message("user"):
        st.write(user_input)
    with st.chat_message("assistant"):
        with st.spinner("Generating your answer..."):
            response = generate_response(user_input)
            st.write(response)
    # Update session state with new messages
    st.session_state.messages.append({"role": "user", "content": user_input})
    st.session_state.messages.append({"role": "assistant", "content": response})