Spaces:
Running
Running
import numpy as np | |
import pandas as pd # type: ignore | |
import os | |
import keras | |
import tensorflow as tf | |
from tensorflow.keras.models import load_model | |
import pymongo | |
import streamlit as st | |
from sentence_transformers import SentenceTransformer | |
from langchain.prompts import ChatPromptTemplate | |
from langchain_openai import ChatOpenAI | |
from langchain.schema.runnable import RunnablePassthrough | |
from langchain.schema.output_parser import StrOutputParser | |
from langchain_core.messages import HumanMessage, SystemMessage | |
from PIL import Image | |
import json | |
import matplotlib.pyplot as plt | |
from matplotlib.colors import LinearSegmentedColormap | |
import textwrap | |
import plotly.graph_objects as go | |
# Page-level configuration; Streamlit requires this to be the first st.* call.
st.set_page_config(
    page_title="Food Chain",
    # Fork-and-knife emoji used as the browser-tab icon.
    page_icon="🍴",
    layout="wide",
)
# Main App
# Remember the configured theme base across script runs so a theme switch
# can be detected on the next run.
if "theme_mode" not in st.session_state:
    st.session_state.theme_mode = st.get_option("theme.base")

# Check for changes in theme mode; when it changed, record it and rerun the
# script immediately so the page is redrawn with the new theme.
current_theme_mode = st.get_option("theme.base")
if current_theme_mode != st.session_state.theme_mode:
    st.session_state.theme_mode = current_theme_mode
    # st.rerun() is the supported API; st.experimental_rerun() no longer exists
    # in the Streamlit releases this app relies on.
    st.rerun()
# The chat model is configured with GLHF_API_KEY explicitly (see the ChatOpenAI
# setup), so OPENAI_API_KEY is not copied into os.environ here: writing
# os.getenv(...) back into os.environ is a no-op when the variable is set and
# raises TypeError when it is unset.
mongo_uri = os.getenv("MONGO_URI_RAG_RECIPE")
@st.cache_resource
def loadEmbedding():
    """Load the ``thenlper/gte-large`` SentenceTransformer used to embed queries.

    Streamlit re-executes this whole script on every widget interaction, so the
    loader is wrapped in ``st.cache_resource``. The model is built once and the
    same object is reused on later reruns instead of being reloaded each time.
    """
    embedding = SentenceTransformer("thenlper/gte-large")
    return embedding


embedding = loadEmbedding()
def getEmbedding(text):
    """Embed *text* with the module-level model and return a plain Python list.

    Blank input (empty or whitespace-only) is not sent to the model: a notice
    is printed and an empty list is returned instead.
    """
    if text.strip():
        vector = embedding.encode(text)
        return vector.tolist()
    print("Text was empty")
    return []
# Connect to MongoDB
@st.cache_resource
def _connect_mongo(mongo_uri):
    """Create a MongoClient for *mongo_uri* and verify it with a ``ping``.

    Cached so Streamlit reruns reuse one client (and its connection pool).
    A raised exception is not a cached return value, so a failed attempt is
    retried on the next run.
    """
    client = pymongo.MongoClient(mongo_uri)
    try:
        # The constructor does not wait for the server, so force a round trip
        # here to surface an unreachable or misconfigured server.
        client.admin.command("ping")
    except pymongo.errors.PyMongoError:
        client.close()
        raise
    return client


def get_mongo_client(mongo_uri):
    """Return a verified MongoClient for *mongo_uri*, or None if connecting fails.

    Any PyMongo error (bad URI, server selection timeout, auth failure, ...) is
    printed and reported as None rather than raised.
    """
    try:
        client = _connect_mongo(mongo_uri)
    except pymongo.errors.PyMongoError as e:
        print(f"Connection failed: {e}")
        return None
    print("Connection to MongoDB successful")
    return client
if not mongo_uri:
    print("MONGO_URI not set in env")
mongo_client = get_mongo_client(mongo_uri)
if mongo_client is None:
    # Without a client the recipe collection cannot be opened; end this run
    # with a visible message instead of a TypeError from subscripting None.
    st.error("Could not connect to the recipe database. Check MONGO_URI_RAG_RECIPE.")
    st.stop()
mongo_db = mongo_client['recipes']
mongo_collection = mongo_db['recipesCollection']
def vector_search(user_query, collection, num_candidates=150, limit=4):
    """Return the recipes in *collection* whose embeddings best match *user_query*.

    Runs a ``$vectorSearch`` aggregation on the ``vector_index`` index over the
    ``embedding`` field, then keeps only the recipe fields plus the search score.

    Args:
        user_query: Free-text query, embedded with ``getEmbedding``.
        collection: pymongo collection to search.
        num_candidates: Number of candidate matches the search considers.
        limit: Maximum number of recipes returned.

    Returns:
        A list of recipe dicts (no ``_id``/``embedding``, plus ``score``), or the
        string "Invalid query or embedding gen failed" when the query is blank.
    """
    query_embedding = getEmbedding(user_query)
    # getEmbedding returns [] (not None) for blank text, so test for emptiness.
    if not query_embedding:
        return "Invalid query or embedding gen failed"

    vector_search_stage = {
        "$vectorSearch": {
            "index": "vector_index",
            "queryVector": query_embedding,
            "path": "embedding",
            "numCandidates": num_candidates,  # Number of candidate matches to consider
            "limit": limit,  # Maximum number of matches to return
        }
    }
    unset_stage = {
        "$unset": "embedding"  # Exclude the 'embedding' field from the results
    }
    project_stage = {
        "$project": {
            "_id": 0,  # Exclude the _id field
            "name": 1,
            "minutes": 1,
            "tags": 1,
            "n_steps": 1,
            "description": 1,
            "ingredients": 1,
            "n_ingredients": 1,
            "formatted_nutrition": 1,
            "formatted_steps": 1,
            "score": {
                "$meta": "vectorSearchScore"  # Include the search score
            },
        }
    }
    pipeline = [vector_search_stage, unset_stage, project_stage]
    # Search the collection the caller passed in, not a module-level global.
    results = collection.aggregate(pipeline)
    return list(results)
def mongo_retriever(query):
    """Context step of the RAG chain: vector-search the recipe collection.

    Logs the incoming query and whatever the search returned to stdout.
    """
    print("mongo retriever query: ", query)
    retrieved = vector_search(query, mongo_collection)
    print("Documents Retrieved: ", retrieved)
    return retrieved
# Prompt for the RAG chain. Doubled braces ({{ }}) are literal braces after
# template formatting; {context} and {question} are filled in by the chain.
template = """
You are an assistant for generating results based on user questions.
Use the provided context to generate a result based on the following JSON format:
{{
    "name": "Recipe Name",
    "minutes": 0,
    "tags": [
        "tag1",
        "tag2",
        "tag3"
    ],
    "n_steps": 0,
    "description": "A GENERAL description of the recipe goes here.",
    "ingredients": [
        "0 tablespoons ingredient1",
        "0 cups ingredient2",
        "0 teaspoons ingredient3"
    ],
    "n_ingredients": 0,
    "formatted_nutrition": [
        "Calorie : per serving",
        "Total Fat : % daily value",
        "Sugar : % daily value",
        "Sodium : % daily value",
        "Protein : % daily value",
        "Saturated Fat : % daily value",
        "Total Carbohydrate : % daily value"
    ],
    "formatted_steps": [
        "1. Step 1 of the recipe.",
        "2. Step 2 of the recipe.",
        "3. Step 3 of the recipe."
    ]
}}
Instructions:
1. Focus on the user's specific request and avoid irrelevant ingredients or approaches.
2. Do not return anything other than the JSON.
3. Base the response on simple, healthy, and accessible ingredients and techniques.
4. Rewrite the description in third person.
5. Include the ingredient amounts and say them in the steps.
6. If the query makes no sense when trying to connect it to a real dish, return []
7. RETURN NOTHING BUT THE JSON
When choosing a recipe from the context, FOLLOW these instructions:
1. The recipe should be makeable from scratch, using only proper ingredients and not other dishes or pre-made recipes
2. If the recipes from the context make sense but do not match {question}, generate an amazing, specific recipe for {question}
with precise steps and measurements. Take some inspiration from the context if available.
3. Follow the above template.
4. If the query makes no sense when trying to connect it to a real dish, return []
5. RETURN NOTHING BUT THE JSON
Context: {context}
Question: {question}
"""
custom_rag_prompt = ChatPromptTemplate.from_template(template)

# Chat model: Llama 3.3 70B Instruct served from glhf.chat's OpenAI-style
# endpoint, authenticated with the GLHF_API_KEY environment variable.
llm = ChatOpenAI(
    model_name="hf:meta-llama/Llama-3.3-70B-Instruct",
    api_key=os.getenv('GLHF_API_KEY'),
    base_url='https://glhf.chat/api/openai/v1',
    temperature=0.2,
)

# RAG pipeline: the incoming question feeds both the Mongo retriever
# ("context") and the prompt verbatim ("question"); the rendered prompt goes
# to the model and the reply is returned as a plain string.
rag_chain = (
    {
        "context": mongo_retriever,
        "question": RunnablePassthrough(),
    }
    | custom_rag_prompt
    | llm
    | StrOutputParser()
)
def get_response(query):
    """Run the RAG chain on *query* and return its text reply.

    A falsy query (empty string or None) short-circuits to "" without
    invoking the chain.
    """
    if not query:
        return ""
    print("get_response query: ", query)
    return rag_chain.invoke(query)
##############################################
# Classifier
# Side length, in pixels, that images are resized to before classification.
img_size = 224


@st.cache_resource
def loadModel():
    """Load the Keras classifier saved in ``efficientnet-fine-d1.keras``.

    Wrapped in ``st.cache_resource`` so the model is read from disk once and
    reused across Streamlit reruns instead of being reloaded on every
    interaction.
    """
    model = load_model('efficientnet-fine-d1.keras')
    return model


model = loadModel()
# The 101 dish labels, in alphabetical order. classifyImage indexes this list
# with the model's output positions, so the order must not change.
class_names = """
    apple_pie baby_back_ribs baklava beef_carpaccio beef_tartare beet_salad
    beignets bibimbap bread_pudding breakfast_burrito bruschetta caesar_salad
    cannoli caprese_salad carrot_cake ceviche cheese_plate cheesecake chicken_curry
    chicken_quesadilla chicken_wings chocolate_cake chocolate_mousse churros clam_chowder
    club_sandwich crab_cakes creme_brulee croque_madame cup_cakes deviled_eggs donuts
    dumplings edamame eggs_benedict escargots falafel filet_mignon fish_and_chips foie_gras
    french_fries french_onion_soup french_toast fried_calamari fried_rice frozen_yogurt
    garlic_bread gnocchi greek_salad grilled_cheese_sandwich grilled_salmon guacamole gyoza
    hamburger hot_and_sour_soup hot_dog huevos_rancheros hummus ice_cream lasagna
    lobster_bisque lobster_roll_sandwich macaroni_and_cheese macarons miso_soup mussels
    nachos omelette onion_rings oysters pad_thai paella pancakes panna_cotta peking_duck
    pho pizza pork_chop poutine prime_rib pulled_pork_sandwich ramen ravioli red_velvet_cake
    risotto samosa sashimi scallops seaweed_salad shrimp_and_grits spaghetti_bolognese
    spaghetti_carbonara spring_rolls steak strawberry_shortcake sushi tacos takoyaki tiramisu
    tuna_tartare waffles
""".split()
def classifyImage(input_image, top_k=5):
    """Classify a PIL image of a dish and return its highest-scoring labels.

    Args:
        input_image: PIL image in any mode; it is converted to RGB and resized
            to ``img_size`` x ``img_size`` before prediction.
        top_k: How many of the highest-scoring classes to return (default 5).

    Returns:
        List of ``(class_name, score * 100)`` tuples, highest score first.
    """
    # Uploaded JPEGs can be grayscale or CMYK; the model input has 3 channels.
    rgb_image = input_image.convert("RGB").resize((img_size, img_size))
    input_array = tf.keras.utils.img_to_array(rgb_image)
    # Add a batch dimension
    input_array = tf.expand_dims(input_array, 0)  # (1, img_size, img_size, 3)
    predictions = model.predict(input_array)[0]
    print(f"Predictions: {predictions}")
    # Indices of the top_k scores, highest first.
    top_indices = np.argsort(predictions)[::-1][:top_k]
    # Pair each index with its class name and the score scaled to a percentage.
    top_predictions = [(class_names[i], predictions[i] * 100) for i in top_indices]
    for rank, (class_name, confidence) in enumerate(top_predictions, 1):
        print(f"{rank}. Predicted {class_name} with {confidence:.1f}% Confidence")
    return top_predictions
def capitalize_after_number(input_string):
    """Upper-case the first letter after a leading step number such as ``"3. "``.

    Only that one character changes, so proper nouns and units later in the
    step keep their case. Strings that do not start with ``"<digits>. "`` are
    returned unchanged.

    >>> capitalize_after_number("2. beat the egg")
    '2. Beat the egg'
    """
    # Split on the first ". " into the step number and the step text.
    num, sep, text = input_string.partition(". ")
    if not sep or not num.strip().isdigit():
        return input_string
    return f"{num}{sep}{text[:1].upper()}{text[1:]}"
##############################################
# for displaying RAG recipe response
def _parse_recipe_response(response):
    """Turn a chain response into a recipe dict, or None when there is no usable recipe.

    Accepts an already-parsed dict, or a string containing a JSON object. The
    text between the first "{" and the last "}" is parsed, so Markdown code
    fences or stray prose around the JSON are tolerated. Empty strings, "[]",
    malformed JSON and non-object JSON all yield None.
    """
    if isinstance(response, dict):
        return response
    if not isinstance(response, str):
        return None
    start, end = response.find("{"), response.rfind("}")
    if start == -1 or end < start:
        return None
    try:
        parsed = json.loads(response[start:end + 1])
    except json.JSONDecodeError as e:
        print(f"Could not parse recipe JSON: {e}")
        return None
    return parsed if isinstance(parsed, dict) else None


def display_response(response):
    """
    Render a recipe response (JSON string or dict) with Streamlit's `st.write()`.

    Shows "No recipes found :(" when the response is empty, "[]", or cannot be
    parsed into a recipe object. Missing fields render as blanks instead of
    raising KeyError.
    """
    recipe = _parse_recipe_response(response)
    if not recipe:
        st.write("No recipes found :(")
        return

    # `or []` also covers fields that are present but null.
    tags = [str(tag) for tag in recipe.get("tags") or []]
    ingredients = [str(item).capitalize() for item in recipe.get("ingredients") or []]
    nutrition = [str(item) for item in recipe.get("formatted_nutrition") or []]
    steps = [str(step) for step in recipe.get("formatted_steps") or []]

    with st.container(height=800):
        st.write(f"**Name**: {str(recipe.get('name', '')).capitalize()}")
        st.write(f"**Preparation Time**: {recipe.get('minutes', '')} minutes")
        st.write(f"**Description**: {str(recipe.get('description', '')).capitalize()}")
        st.write(f"**Tags**: {', '.join(tags)}")
        st.write("### Ingredients")
        st.write(", ".join(ingredients))
        st.write(f"**Total Ingredients**: {recipe.get('n_ingredients', len(ingredients))}")
        st.write("### Nutrition Information (per serving)")
        st.write(", ".join(nutrition))
        st.write(f"Number of Steps: {recipe.get('n_steps', len(steps))}")
        st.write("### Steps")
        for step in steps:
            st.write(capitalize_after_number(step))
def display_dishes_in_grid(dishes, cols=3, container=None):
    """Write dish labels in a grid of *cols* columns per row.

    Each label has underscores replaced by spaces and is capitalized.

    Args:
        dishes: Sequence of dish labels such as ``"apple_pie"``.
        cols: Number of columns per row; must be positive.
        container: Streamlit container to draw into; defaults to ``st.sidebar``.

    Raises:
        ValueError: if *cols* is less than 1.
    """
    if cols < 1:
        raise ValueError("cols must be a positive integer")
    target = st.sidebar if container is None else container
    dishes = list(dishes)
    for start in range(0, len(dishes), cols):
        row = dishes[start:start + cols]
        # Write through each column object so the labels actually land in the grid.
        for column, dish in zip(target.columns(len(row)), row):
            column.write(dish.replace("_", " ").capitalize())
def display_prediction_graph(class_names, confidences):
    """Draw a horizontal Plotly bar chart of classifier confidences in Streamlit.

    Args:
        class_names: Display labels, highest-confidence first.
        confidences: Matching confidence percentages, in the same order.

    The first label is used in the chart title. Nothing is drawn when
    *class_names* is empty.
    """
    if not class_names:
        return
    top_prediction = class_names[0]
    # Plotly ignores "\n" in tick labels; "<br>" is its line break.
    labels = ["<br>".join(textwrap.wrap(name, width=10)) or name for name in class_names]
    # Numeric x values keep bar lengths proportional to the confidences.
    numbers = [float(value) for value in confidences]
    # Plotly places the first y category at the bottom, so reverse both lists
    # to put the top prediction on the top bar.
    labels.reverse()
    numbers.reverse()
    bar_text = [str(round(value, 1)) + "%" for value in numbers]

    # Create a horizontal bar chart
    fig = go.Figure(go.Bar(
        x=numbers,
        y=labels,
        orientation='h',
        marker=dict(color='orange'),
        text=bar_text,  # Show the formatted percentage on each bar
        textposition='auto',  # Inside when it fits, otherwise just outside
    ))
    # Update layout for better appearance
    fig.update_layout(
        title=f"Prediction: {top_prediction}",
        margin=dict(l=20, r=20, t=60, b=20),
        xaxis=dict(
            showgrid=False,  # No grid lines for the x-axis
            ticks='',  # No x-axis ticks
            showticklabels=False,  # No x-axis tick labels
        ),
        yaxis=dict(
            showgrid=False,  # No grid lines for the y-axis
        ),
        plot_bgcolor='rgba(0,0,0,0)',  # Transparent plot area
        paper_bgcolor='rgba(0,0,0,0)',  # Transparent paper area
        font=dict(),  # Default font color
    )
    # Display the chart in Streamlit
    st.plotly_chart(fig)
# Streamlit
# Left sidebar: app title, the image/query inputs, and the dish list.
with st.sidebar:
    st.markdown(
        "<h1 style='font-size:32px;'>Food-Chain</h1>",
        unsafe_allow_html=True,
    )
    st.write("Upload an image and/or enter a query to get started! Explore our trained dish types listed below for guidance.")

    st.markdown('### Food Classification')
    uploaded_image = st.file_uploader("Choose an image:", type="jpg")

    st.markdown('### RAG Recipe')
    query = st.text_area("Enter your query (optional):", height=100)
    recipe_submit = st.button(
        label='Chain Recipe',
        icon=':material/link:',
        use_container_width=True,
    )

    # Vertical gap before the dish list.
    st.markdown("<br><br>", unsafe_allow_html=True)
    st.markdown("### Dish Database")
    selected_dish = st.selectbox(
        "Search for a dish that our model can classify:",
        options=class_names,
        index=0,
    )
# Main title, followed by a collapsible description of the project.
st.title("Welcome to FOOD CHAIN!")
st.expander("**What is FOOD CHAIN?**").markdown(
    """
The project aims to use machine learning and computer vision techniques to analyze food images
and identify them. By using diverse datasets, the model will learn to recognize dishes based on
visual features. Our project aims to inform users about what it is they are eating, including
potential nutritional value and an AI generated response on how their dish might have been prepared.
We want users to have an easy way to figure out what their favorite foods contain, to know any
allergens in the food and to better connect to the food around them. This tool can also tell users
the calories of their dish, they can figure out the nutrients with only a few steps!
Thank you for using our project!
Made by the Classify Crew: [Contact List](https://linktr.ee/classifycrew)
"""
)
#################
# Example recipe shown in the RAG column until the user submits a request.
sample_RAG = dict(
    name="Cinnamon Sugar Baked Donuts",
    minutes=27,
    tags=[
        "30-minutes-or-less",
        "time-to-make",
        "course",
        "cuisine",
        "preparation",
        "occasion",
        "north-american",
        "healthy",
        "desserts",
        "american",
        "dietary",
        "comfort-food",
        "taste-mood",
    ],
    n_steps=10,
    description="A delightful treat with a crusty sugar-cinnamon coating, perfect for a weekend breakfast or snack. Leftovers freeze well.",
    ingredients=[
        "1 cup flour",
        "1 teaspoon baking powder",
        "1 teaspoon cinnamon",
        "1/2 teaspoon nutmeg",
        "1/4 teaspoon mace",
        "1/4 teaspoon salt",
        "1/2 cup sugar",
        "1 egg",
        "1/2 cup milk",
        "2 tablespoons butter, melted",
        "1 teaspoon vanilla",
        "1/4 cup brown sugar",
    ],
    n_ingredients=12,
    formatted_nutrition=[
        "Calorie : 302.9 per serving",
        "Total Fat : 11.0 % daily value",
        "Sugar : 154.0 % daily value",
        "Sodium : 9.0 % daily value",
        "Protein : 7.0 % daily value",
        "Saturated Fat : 22.0 % daily value",
        "Total Carbohydrate : 18.0 % daily value",
    ],
    formatted_steps=[
        "1. Mix all dry ingredients in a medium-size bowl",
        "2. In a smaller bowl, beat the egg",
        "3. Mix the egg with milk and melted butter",
        "4. Add vanilla to the mixture",
        "5. Stir the milk mixture into the dry ingredients until just combined, being careful not to overmix",
        "6. Pour the batter into a greased donut baking tin, filling approximately 3/4 full",
        "7. Mix cinnamon into brown sugar and sprinkle over the donuts",
        "8. Drizzle or spoon melted butter over the top of each donut",
        "9. Bake in a 350-degree oven for 17 minutes",
        "10. Enjoy!",
    ],
)
col1, col2 = st.columns(2)

# Left column: the uploaded photo, or a placeholder plus a sample chart.
with col1:
    st.title("Image Classification")
    if uploaded_image:
        # Open and display image
        input_image = Image.open(uploaded_image)
        st.image(input_image, caption="Uploaded Image.", use_container_width=True)
    else:
        placeholder = Image.open("dish-placeholder.jpg")
        st.image(placeholder, caption="Placeholder Image.", use_container_width=True)
        sample_class_names = ['Donuts', 'Onion Rings', 'Beignets', 'Churros', 'Cup Cakes']
        sample_confidences = [98.1131911277771, 1.3879689387977123, 0.12678804341703653, 0.05296396557241678, 0.04436225863173604]
        display_prediction_graph(sample_class_names, sample_confidences)

# Right column: the sample recipe until the submit button is pressed.
with col2:
    st.title('RAG Recipe')
    if not recipe_submit:
        display_response(sample_RAG)
# Image Classification Section
if recipe_submit and uploaded_image:
    with col1:
        predictions = classifyImage(input_image)
        print("Predictions: ", predictions)
        # Chart labels use their own name so the module-level class_names
        # label list is not overwritten.
        display_names = [name.replace("_", " ").title() for name, _ in predictions]
        confidences = [confidence for _, confidence in predictions]
        # "label: xx.x%" pairs passed to the LLM when a query is given.
        fpredictions = ", ".join(
            f"{name}: {confidence:.1f}%" for name, confidence in predictions
        )
        print(fpredictions)
        display_prediction_graph(display_names, confidences)

    with col2, st.spinner("Generating..."):
        if query:
            # Ask the LLM to pick the classified dish that best fits the query.
            openAICall = [
                SystemMessage(
                    content="You are a helpful assistant that identifies the best match between classified food items and a user's request based on provided classifications and keywords."
                ),
                HumanMessage(
                    content=f"""
Based on the following image classification with percentages of each food:
{fpredictions}
And the following user request:
{query}
1. If the user's query relates to any of the classified predictions (even partially or conceptually), select the most relevant dish from the predictions.
2. If the query does not align with the predictions, disregard them and suggest a dish that best matches the user's query.
3. Consider culture, ingredients, cooking steps, etc.
4. Return in the format: [dish]
5. ONLY return the name of the dish in brackets.
Example 1:
Predictions: apple pie: 50%, cherry tart: 30%, vanilla ice cream: 20%
User query: pumpkin
YOUR Response: [pumpkin pie]
Example 2:
Predictions: spaghetti: 60%, lasagna: 30%, salad: 10%
User query: pasta with layers
YOUR Response: [lasagna]
Example 3:
Predictions: sushi: 70%, sashimi: 20%, ramen: 10%
User query: noodles
YOUR Response: [ramen]
"""
                ),
            ]
            openAIresponse = llm.invoke(openAICall)
            print("AI CALL RESPONSE: ", openAIresponse.content, "END AI CALL RESONSE")
            RAGresponse = get_response(openAIresponse.content + " " + query)
        else:
            # Labels are snake_case (e.g. "apple_pie"); search with readable words.
            RAGresponse = get_response(predictions[0][0].replace("_", " "))
        print("RAGresponse: ", RAGresponse)
        display_response(RAGresponse)
elif recipe_submit and query:
    with col2, st.spinner("Generating..."):
        response = get_response(query)
        print(response)
        display_response(response)
elif recipe_submit:
    # Warn only when the button was pressed with neither an image nor a query.
    st.warning('Please input an image or a query.', icon=":material/warning:")