"""LawGPT: a FastAPI service that answers Indian-law questions with RAG.

Per request: detect the question's language, translate it to English, apply a
keyword relevance filter, run a FAISS-backed ConversationalRetrievalChain on a
Together.ai-hosted Mistral model, then translate the answer back.
"""

import logging
import os
import re
from pathlib import Path

# Set before any Hugging Face-related import: some of those libraries read
# cache locations from the environment at import time. setdefault lets a
# deployment override the location.
os.environ.setdefault("TRANSFORMERS_CACHE", "/tmp/cache")

from fastapi import APIRouter, FastAPI, HTTPException  # noqa: E402
from fastapi.concurrency import run_in_threadpool  # noqa: E402
from pydantic import BaseModel  # noqa: E402

# Deprecated duplicate of the class below. It is imported first so that the
# maintained langchain_huggingface implementation is what the name resolves to.
from langchain_community.embeddings import HuggingFaceEmbeddings  # noqa: E402,F401
from langchain_huggingface import HuggingFaceEmbeddings  # noqa: E402,F811
from langchain_community.vectorstores import FAISS  # noqa: E402
from langchain.prompts import PromptTemplate  # noqa: E402
from langchain_together import Together  # noqa: E402
from langchain.memory import ConversationBufferWindowMemory  # noqa: E402,F401  (unused: chain is stateless)
from langchain.chains import ConversationalRetrievalChain  # noqa: E402
from langdetect import detect  # noqa: E402,F401
from googletrans import LANGUAGES, Translator  # noqa: E402

logger = logging.getLogger(__name__)

# Together.ai credentials come from the environment only. The key that used to
# be hard-coded in this file is in version-control history and should be
# treated as compromised and rotated.
TOGETHER_API_KEY = os.getenv("TOGETHER_AI_API") or os.getenv("TOGETHER_API_KEY")
if not TOGETHER_API_KEY:
    raise RuntimeError(
        "Together.ai API key not configured: set TOGETHER_AI_API (or TOGETHER_API_KEY)."
    )

# Initialize FastAPI app and router. The router is included into the app at
# the very end of this module: include_router copies the routes that exist at
# call time, so including it before the endpoints are declared exposes nothing.
app = FastAPI()
router = APIRouter()

bot_name = "LawGPT"

# Embeddings and the FAISS index are loaded eagerly at import time.
embeddings = HuggingFaceEmbeddings(
    model_name="nomic-ai/nomic-embed-text-v1",
    model_kwargs={"trust_remote_code": True, "revision": "289f532e14dbbbd5a04753fa58739e9ba766f3c7"},
)

index_path = Path("models/index.faiss")
if not index_path.exists():
    raise FileNotFoundError("FAISS index not found. Please generate it and place it in 'models'.")

# allow_dangerous_deserialization unpickles the index metadata. That is only
# acceptable because the files are generated locally; never point this at
# files from an untrusted source.
db = FAISS.load_local(str(index_path.parent), embeddings, allow_dangerous_deserialization=True)
db_retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 4})

# Prompt for the legal chatbot, in Mistral-instruct format ([INST] ... [/INST]).
prompt_template = """[INST]This is a chat template and as a legal chatbot specializing in Indian Penal Code queries, your objective is to provide accurate and concise information.
CONTEXT: {context}
CHAT HISTORY: {chat_history}
QUESTION: {question}
ANSWER: [/INST]"""
prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question", "chat_history"])

# The LLM that generates answers (and condenses follow-up questions).
llm = Together(
    model="mistralai/Mistral-7B-Instruct-v0.2",
    temperature=0.5,
    max_tokens=1024,
    together_api_key=TOGETHER_API_KEY,
)


def _chat_history_to_text(chat_history: str) -> str:
    """Return the client-supplied chat history as the chain's history string.

    The API receives history as plain text, so it is passed through (stripped).
    An empty result makes the chain skip the question-condensing step.
    """
    return chat_history.strip()


# Stateless chain: the conversation history is supplied per request.
# No server-side memory object is attached, for two reasons:
# - A single module-level memory would be shared by every caller, leaking one
#   user's conversation into another's prompt.
# - Chain.prep_inputs lets memory variables override the client's chat_history.
qa_chain = ConversationalRetrievalChain.from_llm(
    llm=llm,
    retriever=db_retriever,
    combine_docs_chain_kwargs={"prompt": prompt},
    get_chat_history=_chat_history_to_text,
)

# googletrans >= 4.0.2 exposes coroutine methods (hence the awaits below).
# NOTE(review): confirm the pinned googletrans version is an async one.
translator = Translator()


class ChatRequest(BaseModel):
    """Request body for POST /lawgpt/chat/."""

    question: str
    # Plain-text conversation history; optional so clients may omit it.
    chat_history: str = ""


# Case-insensitive substrings; at least one must appear in a valid question.
LEGAL_KEYWORDS = (
    "IPC", "CrPC", "section", "law", "crime", "penalty", "punishment", "legal",
    "court", "justice", "offense", "fraud", "murder", "theft", "bail", "arrest",
    "FIR", "judgment", "act", "contract", "constitutional", "habeas corpus",
    "petition", "rights", "lawyer", "advocate", "accused", "penal", "conviction",
    "sentence", "appeal", "trial", "witness",
)
_LEGAL_KEYWORDS_LOWER = tuple(keyword.lower() for keyword in LEGAL_KEYWORDS)


def is_valid_question(question: str) -> bool:
    """Heuristically decide whether *question* looks like a legal query.

    Returns False when the stripped text is any of:
    - shorter than 3 characters;
    - only digits;
    - only symbols (no letters, digits or whitespace);
    - free of every entry in ``LEGAL_KEYWORDS``.

    Keyword matching is case-insensitive *substring* matching, so it is
    deliberately loose: "illegal" matches "legal", but "first" also matches
    "fir".
    """
    question = question.strip()

    if len(question) < 3:
        return False

    if re.match(r'^\d+$', question) or re.match(r'^[^a-zA-Z0-9\s]+$', question):
        return False

    lowered = question.lower()
    return any(keyword in lowered for keyword in _LEGAL_KEYWORDS_LOWER)


async def _detect_language(text: str) -> str:
    """Return the language code googletrans detects for *text*."""
    detected = await translator.detect(text)
    lang = detected.lang
    # NOTE(review): some googletrans versions return a list of candidate codes
    # for mixed-language input; take the first in that case.
    if isinstance(lang, list):
        lang = lang[0]
    return lang


async def _translate(text: str, *, src: str, dest: str) -> str:
    """Translate *text* from *src* to *dest*, skipping same-language calls."""
    if src == dest:
        return text
    translation = await translator.translate(text, src=src, dest=dest)
    return translation.text


@router.post("/chat/")
async def chat(request: ChatRequest):
    """Answer a legal question in the language it was asked in.

    Returns ``{"answer": str, "language": str}``, where ``language`` is the
    googletrans name of the detected language (or "unknown").

    Raises:
        HTTPException: status 500 if detection, translation or the chain
            fails. Details are logged server-side, not returned to the client.
    """
    try:
        detected_language = await _detect_language(request.question)
        # googletrans LANGUAGES keys are lowercase (e.g. "zh-cn").
        language_name = LANGUAGES.get(detected_language.lower(), "unknown")

        question_in_english = await _translate(request.question, src=detected_language, dest="en")

        if not is_valid_question(question_in_english):
            return {
                "answer": "Please provide a valid legal question related to Indian laws.",
                "language": language_name,
            }

        inputs = {"question": question_in_english, "chat_history": request.chat_history}
        # The chain is synchronous (embedding + remote LLM call). Running it in
        # the threadpool keeps it from blocking the event loop for other requests.
        result = await run_in_threadpool(qa_chain.invoke, inputs)

        if "answer" not in result:
            raise ValueError("Missing 'answer' key in the result from qa_chain")

        answer_in_original_language = await _translate(
            result["answer"], src="en", dest=detected_language
        )
        return {
            "answer": answer_in_original_language,
            "language": language_name,
        }
    except Exception as exc:
        logger.exception("Failed to process chat request")
        raise HTTPException(status_code=500, detail="Error processing request.") from exc


@router.get("/")
async def root():
    """Health check."""
    return {"message": "LawGPT API is running."}


# Include the router only now, after every endpoint above has been registered.
app.include_router(router, prefix="/lawgpt")