# ani / app.py
# Uploaded by: aniudupa
# Commit: 2f5e3d5 "Update app.py" (verified)
import asyncio
import inspect
import logging
import os
import re
from pathlib import Path

from fastapi import FastAPI, APIRouter, HTTPException
from pydantic import BaseModel
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
# NOTE(review): this import rebinds HuggingFaceEmbeddings, shadowing the
# langchain_huggingface class imported above, so the community class is the one
# actually used below. Keep only one of the two.
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.prompts import PromptTemplate
from langchain_together import Together
from langchain.memory import ConversationBufferWindowMemory
from langchain.chains import ConversationalRetrievalChain
from langdetect import detect
from googletrans import Translator, LANGUAGES
# Together.ai API key. It must be supplied through the TOGETHER_AI_API
# environment variable (e.g. a deployment secret) and is never committed to
# source control.
# NOTE(review): a key used to be hard-coded on this line; treat it as leaked
# and rotate it.
if not os.environ.get("TOGETHER_AI_API"):
    logging.getLogger(__name__).warning(
        "TOGETHER_AI_API is not set; Together.ai requests will fail unless a key "
        "is provided another way."
    )
# Cache directory for downloaded models. setdefault keeps a value the operator
# has already configured instead of silently overwriting it.
os.environ.setdefault('TRANSFORMERS_CACHE', '/tmp/cache')
# Initialize FastAPI app
app = FastAPI()
# Router that collects the LawGPT endpoints declared with @router below.
router = APIRouter()
# NOTE: do not call app.include_router() at this point. FastAPI copies a
# router's routes at the moment it is included, and no routes have been
# declared on `router` yet, so an include here would register nothing.
# Include the router after the last @router endpoint instead.
bot_name = "LawGPT"
# Embedding model and FAISS vector store. Both are loaded eagerly, at import
# time, and back the retrieval chain used by the chat endpoint.
index_dir = Path("models")
index_path = index_dir / "index.faiss"
# FAISS.load_local (default index name "index") reads both "index.faiss" (the
# vector index) and "index.pkl" (the pickled docstore). Check both *before*
# loading the embedding model, so a missing index fails fast instead of after
# a potentially slow model load/download.
_missing_index_files = [p for p in (index_path, index_dir / "index.pkl") if not p.exists()]
if _missing_index_files:
    raise FileNotFoundError(
        "FAISS index not found. Please generate it and place it in 'models'. "
        f"Missing: {', '.join(map(str, _missing_index_files))}"
    )

embeddings = HuggingFaceEmbeddings(
    model_name="nomic-ai/nomic-embed-text-v1",
    # The model needs trust_remote_code; pinning an exact revision keeps the
    # remote code that gets executed fixed.
    model_kwargs={"trust_remote_code": True, "revision": "289f532e14dbbbd5a04753fa58739e9ba766f3c7"},
)

# Load the FAISS index.
# allow_dangerous_deserialization=True is needed because the docstore is a
# pickle. Unpickling can execute arbitrary code, so only load index files that
# this project generated itself.
db = FAISS.load_local(str(index_dir), embeddings, allow_dangerous_deserialization=True)
# Similarity search returning the 4 closest chunks per query.
db_retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 4})
# Prompt for the answer-generation step of the retrieval chain. {context},
# {chat_history} and {question} are filled in by the chain at call time.
# The text uses Mistral-Instruct markers: the user turn opens with "[INST]"
# and must be closed with "[/INST]", so the model generates its answer after
# the closing marker. (Ending with "</s>[INST]" would instead emit an
# end-of-sequence marker followed by a new, never-closed instruction.)
prompt_template = """<s>[INST]This is a chat template and as a legal chatbot specializing in Indian Penal Code queries, your objective is to provide accurate and concise information.
CONTEXT: {context}
CHAT HISTORY: {chat_history}
QUESTION: {question}
ANSWER:
[/INST]"""
prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question", "chat_history"])
# LLM behind the conversational retrieval chain: Mistral 7B Instruct, hosted on
# Together.ai, authenticated with the key from the TOGETHER_AI_API variable.
_llm_settings = dict(
    model="mistralai/Mistral-7B-Instruct-v0.2",
    temperature=0.5,
    max_tokens=1024,
)
llm = Together(together_api_key=os.getenv("TOGETHER_AI_API"), **_llm_settings)
# Sliding-window conversation memory holding the most recent k=2 exchanges.
# It is exposed to the chain under the "chat_history" key as message objects.
# NOTE(review): this single module-level instance is shared by every request
# that goes through qa_chain, so different users' turns end up in each other's
# history. Consider per-session memory.
memory = ConversationBufferWindowMemory(
    memory_key="chat_history",
    return_messages=True,
    k=2,
)
# Conversational retrieval chain:
# - fetches supporting passages through db_retriever,
# - tracks the dialogue in `memory`,
# - answers with `llm`, formatting the final call with the custom `prompt`.
_combine_docs_kwargs = dict(prompt=prompt)
qa_chain = ConversationalRetrievalChain.from_llm(
    llm=llm,
    retriever=db_retriever,
    memory=memory,
    combine_docs_chain_kwargs=_combine_docs_kwargs,
)
# Shared googletrans client. The chat endpoint uses it to detect the
# question's language and to translate to/from English.
# NOTE(review): googletrans' API differs between releases (newer ones expose
# coroutine methods, older ones are synchronous). Callers must handle
# whichever form the installed version provides.
translator = Translator()
# Input schema for chat requests
class ChatRequest(BaseModel):
    """Request body for the chat endpoint.

    Attributes:
        question: The user's question; it may be in any language, since the
            endpoint detects and translates it.
        chat_history: Prior conversation as free text. Optional; defaults to
            an empty string so clients may omit it.
    """

    question: str
    # NOTE(review): qa_chain is built with a memory object keyed
    # "chat_history", which presumably supersedes this value inside the
    # chain. Confirm before relying on it.
    chat_history: str = ""
# Lower-cased keywords whose presence marks a question as legal in nature.
# Matching is by plain substring, so it is deliberately lenient: "laws" and
# "offenses" match, but so do "fact" (via "act") and "first" (via "fir").
_LEGAL_KEYWORDS = tuple(
    term.lower()
    for term in (
        "IPC", "CrPC", "section", "law", "crime", "penalty", "punishment",
        "legal", "court", "justice", "offense", "fraud", "murder", "theft",
        "bail", "arrest", "FIR", "judgment", "act", "contract", "constitutional",
        "habeas corpus", "petition", "rights", "lawyer", "advocate", "accused",
        "penal", "conviction", "sentence", "appeal", "trial", "witness",
    )
)
# Questions made up solely of digits, or solely of punctuation/symbols.
_DIGITS_ONLY = re.compile(r'^\d+$')
_SYMBOLS_ONLY = re.compile(r'^[^a-zA-Z0-9\s]+$')


def is_valid_question(question: str) -> bool:
    """Return True if *question* looks like a meaningful legal question.

    After trimming surrounding whitespace, the text is rejected when it is
    shorter than 3 characters, or consists only of digits or only of symbols.
    Otherwise it is accepted if it contains at least one legal keyword,
    compared case-insensitively as a substring.
    """
    text = question.strip()
    if len(text) < 3:
        return False
    if _DIGITS_ONLY.match(text) or _SYMBOLS_ONLY.match(text):
        return False
    lowered = text.lower()
    return any(term in lowered for term in _LEGAL_KEYWORDS)
# Module logger; records the full traceback of failed chat requests.
logger = logging.getLogger(__name__)


async def _maybe_await(value):
    """Return *value*, awaiting it first if it is awaitable.

    googletrans methods are coroutines in some releases and plain synchronous
    calls in others. This lets the endpoint work with either, instead of
    failing when ``await`` is applied to a non-awaitable result.
    """
    if inspect.isawaitable(value):
        return await value
    return value


def _normalize_lang(lang) -> str:
    """Reduce a googletrans detection result to one lowercase language code.

    Some googletrans releases may report a list of candidate codes for mixed
    text; the first one is used. Lowercasing makes codes such as "zh-CN" match
    the lowercase keys of ``LANGUAGES``.
    """
    if isinstance(lang, (list, tuple)):
        lang = lang[0]
    return lang.lower()


# POST endpoint to handle chat requests
@router.post("/chat/")
async def chat(request: ChatRequest):
    """Answer a legal question in the user's own language.

    Steps:
    1. Detect the question's language and translate the question to English.
    2. Check the English text with ``is_valid_question``.
    3. Answer it with ``qa_chain``.
    4. Translate the answer back to the detected language.

    English input skips both translation calls.

    Returns:
        A dict with two keys:
        - ``"answer"``: the answer in the user's language, or an English
          rejection message for questions that fail validation.
        - ``"language"``: the googletrans name of the detected language, or
          ``"unknown"``.

    Raises:
        HTTPException: status 500 wrapping any failure during processing.
    """
    try:
        # Detect language
        detected = await _maybe_await(translator.detect(request.question))
        detected_language = _normalize_lang(detected.lang)
        language_name = LANGUAGES.get(detected_language, "unknown")
        is_english = detected_language == "en"

        # Translate question to English (no round-trip needed if already English)
        if is_english:
            question_in_english = request.question
        else:
            question_translation = await _maybe_await(
                translator.translate(request.question, src=detected_language, dest="en")
            )
            question_in_english = question_translation.text

        # Validate translated question
        if not is_valid_question(question_in_english):
            return {
                "answer": "Please provide a valid legal question related to Indian laws.",
                "language": language_name,
            }

        inputs = {"question": question_in_english, "chat_history": request.chat_history}
        # qa_chain is a synchronous call that does embedding, FAISS search and
        # the LLM request. Run it in a worker thread so it does not stall the
        # event loop (and every other request) while it works.
        result = await asyncio.to_thread(qa_chain.invoke, inputs)

        # Ensure response contains an answer
        if "answer" not in result:
            raise ValueError("Missing 'answer' key in the result from qa_chain")

        # Translate response back to original language
        if is_english:
            answer_in_original_language = result["answer"]
        else:
            answer_translation = await _maybe_await(
                translator.translate(result["answer"], src="en", dest=detected_language)
            )
            answer_in_original_language = answer_translation.text

        return {
            "answer": answer_in_original_language,
            "language": language_name,
        }
    except Exception as e:
        logger.exception("Error processing chat request")
        raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}") from e
# GET endpoint to check if the API is running
@router.get("/")
async def root():
    """Health check: report that the LawGPT API is up."""
    return {"message": "LawGPT API is running."}


# Mount the router only now, after every @router endpoint above has been
# declared. FastAPI copies a router's routes when include_router() is called,
# so including it any earlier would leave those endpoints unregistered.
app.include_router(router, prefix="/lawgpt")