# Spaces: Runtime error
# (page-status text captured along with the source; not part of the program)
import os | |
import re | |
import gradio as gr | |
from dotenv import load_dotenv | |
from langchain_community.utilities import SQLDatabase | |
from langchain_openai import ChatOpenAI | |
from langchain.chains import create_sql_query_chain | |
from langchain_core.output_parsers import StrOutputParser | |
from langchain_core.prompts import ChatPromptTemplate | |
from langchain_core.runnables import RunnablePassthrough | |
from langchain_core.output_parsers.openai_tools import PydanticToolsParser | |
from langchain_core.pydantic_v1 import BaseModel, Field | |
from typing import List | |
import sqlite3 | |
from langsmith import traceable | |
from openai import OpenAI | |
# Load environment variables from .env file
load_dotenv()

# Set up LangSmith.
# os.environ only accepts str values, so assigning os.getenv(...) straight
# back raises TypeError whenever LANGCHAIN_API_KEY is not configured. The key
# is therefore read once, and tracing is switched on only when it is present.
# (Writing the key back is unnecessary: it already lives in the environment.)
_langchain_api_key = os.getenv("LANGCHAIN_API_KEY")
if _langchain_api_key:
    os.environ["LANGCHAIN_TRACING_V2"] = "true"
    os.environ["LANGCHAIN_PROJECT"] = "SQLq&a"

# Initialize OpenAI client
openai_client = OpenAI()

# Set up the database connection: chinook.db is looked up in the directory
# that contains this file.
db_path = os.path.join(os.path.dirname(__file__), "chinook.db")
db = SQLDatabase.from_uri(f"sqlite:///{db_path}")
# Function to get table info
def get_table_info(db_path):
    """Return the column names of every table in a SQLite database.

    Args:
        db_path: Path of the SQLite database file to inspect.

    Returns:
        A dict mapping each table name listed in ``sqlite_master`` to the
        list of its column names, in the order ``PRAGMA table_info`` reports
        them.

    Raises:
        sqlite3.Error: If the database cannot be opened or queried.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()

        # Get all table names
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        table_names = [row[0] for row in cursor.fetchall()]

        table_info = {}
        for table_name in table_names:
            # Quote the identifier (doubling embedded quotes) so that names
            # containing spaces or punctuation do not break the PRAGMA.
            quoted_name = '"' + table_name.replace('"', '""') + '"'
            cursor.execute(f"PRAGMA table_info({quoted_name})")
            # Field 1 of each PRAGMA table_info row is the column name.
            table_info[table_name] = [column[1] for column in cursor.fetchall()]
        return table_info
    finally:
        # Release the connection even when one of the queries raises.
        conn.close()
# Snapshot of the schema (table name -> column names), read once at import
# time from the same database file the chains query.
table_info = get_table_info(db_path=db_path)
# Format table info for display
def format_table_info(table_info):
    """Render a table -> columns mapping as a plain-text listing.

    Args:
        table_info: Mapping of table name to an iterable of column names.

    Returns:
        A string that starts with the table count and a heading, followed by
        one section per table: the table name, one `` - column`` line per
        column, and a blank line.
    """
    pieces = [
        f"Total number of tables: {len(table_info)}\n\n",
        "Tables and their columns:\n\n",
    ]
    for table, columns in table_info.items():
        pieces.append(f"{table}:\n")
        pieces.extend(f" - {column}\n" for column in columns)
        pieces.append("\n")
    return "".join(pieces)
# Chat model shared by table selection, SQL generation and answer generation.
llm = ChatOpenAI(
    model="gpt-3.5-turbo",
    temperature=0,
)
class Table(BaseModel):
    """Table in SQL database."""

    # This model is bound to the LLM as a tool and parsed back by
    # PydanticToolsParser further down in this file.
    # NOTE(review): the class docstring and the field description are
    # presumably forwarded to the model as part of the tool schema — keep
    # their wording stable unless a prompt change is intended.
    name: str = Field(
        description="Name of table in SQL database.",
    )
# Table selection: ask the model which tables could matter for a question.
# One usable table name per line, as reported by the SQLDatabase wrapper.
table_names = "\n".join(db.get_usable_table_names())

# System instructions listing the available tables. The table list is
# interpolated here, so only "{input}" remains a prompt variable.
system = (
    "Return the names of ALL the SQL tables that MIGHT be relevant to the user question. "
    "The tables are:\n"
    f"{table_names}\n"
    "Remember to include ALL POTENTIALLY RELEVANT tables, even if you're not sure that they're needed."
)

table_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "{input}"),
    ]
)

# The model answers through the Table tool; the parser turns those tool
# calls back into Table instances.
llm_with_tools = llm.bind_tools([Table])
output_parser = PydanticToolsParser(tools=[Table])

table_chain = table_prompt | llm_with_tools | output_parser
# Function to get table names from the output
def get_table_names(output: List[Table]) -> List[str]:
    """Extract the ``name`` of each selected table.

    Args:
        output: Table objects produced by the table-selection chain.

    Returns:
        The table names, in the same order as ``output``.
    """
    names = []
    for selected in output:
        names.append(selected.name)
    return names
# SQL generation: turns a question into a SQL query for `db`.
query_chain = create_sql_query_chain(llm, db)

# Step that adds "table_names_to_use" to the input dict by running the
# table-selection chain on the incoming question.
_table_selection_step = RunnablePassthrough.assign(
    table_names_to_use=lambda inputs: get_table_names(
        table_chain.invoke({"input": inputs["question"]})
    )
)

# Combine table selection and query generation
full_chain = _table_selection_step | query_chain
# Function to strip markdown formatting from SQL query
def strip_markdown(text):
    """Remove Markdown code-fence markers from a SQL query string.

    Opening fences are removed together with an optional ``sql``, ``sqlite``
    or ``sqlite3`` language tag in any letter case. A case-sensitive match on
    ``sql`` alone would leave fragments such as ``SQL`` or ``ite`` at the
    start of the query.

    Args:
        text: Query text, possibly wrapped in a fenced code block.

    Returns:
        The text without fence markers and without leading/trailing
        whitespace.
    """
    # First alternative: opening fence + optional language tag + whitespace.
    # Second alternative: whitespace + closing (or untagged) fence.
    text = re.sub(
        r'```(?:sql(?:ite3?)?\b)?\s*|\s*```',
        '',
        text,
        flags=re.IGNORECASE,
    )
    # Remove any leading/trailing whitespace
    return text.strip()
# Function to execute SQL query
def execute_query(query: str) -> str:
    """Run a SQL query against ``db`` and return the outcome as text.

    Markdown fences are stripped from ``query`` before it is executed. Any
    exception is converted into an "Error executing query: ..." string
    instead of propagating, so the caller always receives a string.

    NOTE(review): the query text is executed as-is; nothing in this function
    restricts it to read-only statements.

    Args:
        query: SQL text, possibly wrapped in Markdown code fences.

    Returns:
        ``str()`` of the database result, or an error description.
    """
    try:
        return str(db.run(strip_markdown(query)))
    except Exception as exc:
        return f"Error executing query: {str(exc)}"
# Answer generation: prompt that receives the question, the generated SQL,
# its result (or error text) and the schema description.
_ANSWER_SYSTEM_TEMPLATE = (
    "Given the following user question, corresponding SQL query, and SQL result, answer the user question.\n"
    "If there was an error in executing the SQL query, please explain the error and suggest a correction.\n"
    "Do not include any SQL code formatting or markdown in your response.\n"
    "Here is the database schema for reference:\n"
    "{table_info}"
)
_ANSWER_HUMAN_TEMPLATE = "Question: {question}\nSQL Query: {query}\nSQL Result: {result}\nAnswer:"

answer_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", _ANSWER_SYSTEM_TEMPLATE),
        ("human", _ANSWER_HUMAN_TEMPLATE),
    ]
)

# Assemble the final chain step by step:
#   1. add "query"  - SQL produced by full_chain from the input dict
#   2. add "result" - outcome of running that SQL
#   3. fill the answer prompt, call the model, return plain text
_with_query = RunnablePassthrough.assign(query=lambda state: full_chain.invoke(state))
_with_result = _with_query.assign(result=lambda state: execute_query(state["query"]))
chain = _with_result | answer_prompt | llm | StrOutputParser()
# Function to process user input and generate response
def process_input(message, history, table_info_str):
    """Answer one chat message by running the full question -> SQL -> answer chain.

    Args:
        message: The user's question.
        history: Accepted but not used by this function.
        table_info_str: Schema description inserted into the answer prompt.

    Returns:
        Whatever ``chain.invoke`` returns for this question.
    """
    payload = {"question": message, "table_info": table_info_str}
    return chain.invoke(payload)
# Human-readable schema text shown in the UI and passed to process_input.
formatted_table_info = format_table_info(table_info)

# Sample questions offered in the chat UI.
_EXAMPLE_QUESTIONS = [
    "Who are the top 5 artists with the most albums in the database?",
    "What is the total sales amount for each country?",
    "Which employee has made the highest total sales, and what is the amount?",
    "What are the top 10 longest tracks in the database, and who are their artists?",
    "How many customers are there in each country, and what is the total sales for each?",
]

# Read-only textbox holding the schema; its value is handed to process_input
# as the additional input (table_info_str).
_schema_box = gr.Textbox(
    label="Database Schema",
    value=formatted_table_info,
    lines=10,
    max_lines=20,
    interactive=False,
)

# Create Gradio interface
iface = gr.ChatInterface(
    fn=process_input,
    title="SQL Q&A Chatbot for Chinook Database",
    description="Ask questions about the Chinook music store database and get answers!",
    # Each example is a one-element list containing just the question.
    examples=[[question] for question in _EXAMPLE_QUESTIONS],
    additional_inputs=[_schema_box],
    theme="soft",
)

# Launch the interface only when this file is run as a script.
if __name__ == "__main__":
    iface.launch()