import os
from functools import lru_cache

import streamlit as st
import torch
from dotenv import load_dotenv
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load environment variables (expects HF_TOKEN in a local .env file)
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")

# App title and description
st.title("I am Your GrowBuddy 🌱")
st.write("Let me help you start gardening. Let's grow together!")


# Load the tokenizer and model only once per session (cached by Streamlit)
@st.cache_resource
def load_model():
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            "TheSheBots/UrbanGardening", token=HF_TOKEN, use_fast=True
        )
        # The model is loaded in plain float32 on CPU; no quantization is applied here
        model = AutoModelForCausalLM.from_pretrained(
            "google/gemma-2b-it", token=HF_TOKEN, torch_dtype=torch.float32
        )
        return tokenizer, model
    except Exception as e:
        st.error(f"Failed to load model: {e}")
        return None, None


# Load model and tokenizer (cached)
tokenizer, model = load_model()
if not tokenizer or not model:
    st.stop()

# Run inference on CPU and switch the model to evaluation mode
device = torch.device("cpu")
model = model.to(device)
model.eval()

# Initialize session state messages
if "messages" not in st.session_state:
    st.session_state.messages = [
        {"role": "assistant", "content": "Hello there! How can I help you with gardening today?"}
    ]

# Display conversation history
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.write(message["content"])


# Cache answers to repeated prompts to avoid redundant computation.
# Only the prompt string is used as the cache key; the tokenizer and model
# are module-level constants, so they are not passed in (which would also
# make the arguments unhashable for lru_cache).
@lru_cache(maxsize=100)
def cached_generate_response(prompt):
    inputs = tokenizer(
        prompt, return_tensors="pt", truncation=True, padding=True, max_length=512
    ).to(device)
    with torch.no_grad():
        outputs = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=50,
            temperature=0.7,
            do_sample=True,
        )
    # Decode only the newly generated tokens, not the echoed prompt
    new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)


# Generate a response, falling back to an error message on failure
def generate_response(prompt):
    try:
        return cached_generate_response(prompt)
    except Exception as e:
        st.error(f"Error during text generation: {e}")
        return "Sorry, I couldn't process your request."


# User input field for gardening questions
user_input = st.chat_input("Type your gardening question here:")

if user_input:
    with st.chat_message("user"):
        st.write(user_input)

    with st.chat_message("assistant"):
        with st.spinner("Generating your answer..."):
            response = generate_response(user_input)
            st.write(response)

    # Update session state with the new messages
    st.session_state.messages.append({"role": "user", "content": user_input})
    st.session_state.messages.append({"role": "assistant", "content": response})
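

# Usage note (an assumption, not part of the original script): if this file is
# saved as app.py with a .env file containing HF_TOKEN=<your Hugging Face token>
# in the same directory, the app can be started locally with:
#   streamlit run app.py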