"""LawGPT: FastAPI service that answers Indian Penal Code questions using a
retrieval-augmented LangChain pipeline, with automatic translation of
non-English questions and answers."""

from fastapi import FastAPI, APIRouter, HTTPException
from pydantic import BaseModel
from pathlib import Path
import os
import re

from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.prompts import PromptTemplate
from langchain_together import Together
from langchain.memory import ConversationBufferWindowMemory
from langchain.chains import ConversationalRetrievalChain
from googletrans import Translator, LANGUAGES

# Together AI API key: supply it via the environment (or a .env file) rather than
# committing a real key to source control.
os.environ.setdefault("TOGETHER_AI_API", "<YOUR_TOGETHER_AI_API_KEY>")

# Cache Hugging Face model downloads in a writable location.
os.environ["TRANSFORMERS_CACHE"] = "/tmp/cache"

app = FastAPI()
router = APIRouter()

# NOTE: the router is included on the app at the bottom of this file, after all
# routes have been declared; include_router() copies the routes that exist at
# call time, so including it here would register an empty router.

bot_name = "LawGPT"

# Embedding model used both to build the FAISS index and to embed incoming queries;
# it must match the model the index was created with.
embeddings = HuggingFaceEmbeddings(
    model_name="nomic-ai/nomic-embed-text-v1",
    model_kwargs={"trust_remote_code": True, "revision": "289f532e14dbbbd5a04753fa58739e9ba766f3c7"},
)

# The FAISS index is built offline and shipped in the 'models' directory.
index_path = Path("models/index.faiss")
if not index_path.exists():
    raise FileNotFoundError("FAISS index not found. Please generate it and place it in 'models'.")

db = FAISS.load_local("models", embeddings, allow_dangerous_deserialization=True)
db_retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 4})

# Prompt for the answer-generation step, written in the Mistral instruction format.
prompt_template = """<s>[INST] As a legal chatbot specializing in Indian Penal Code queries, your objective is to provide accurate and concise information.
CONTEXT: {context}
CHAT HISTORY: {chat_history}
QUESTION: {question}
ANSWER: [/INST]"""

prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question", "chat_history"])

llm = Together(
    model="mistralai/Mistral-7B-Instruct-v0.2",
    temperature=0.5,
    max_tokens=1024,
    together_api_key=os.getenv("TOGETHER_AI_API"),
)

# Keep only the last two exchanges in memory. Note that this memory object is
# module-level and therefore shared by every client of the API.
memory = ConversationBufferWindowMemory(k=2, memory_key="chat_history", return_messages=True)

# Retrieval-augmented chain: condenses the question using the chat history,
# retrieves relevant passages from the FAISS index, and answers with the prompt above.
qa_chain = ConversationalRetrievalChain.from_llm(
    llm=llm,
    memory=memory,
    retriever=db_retriever,
    combine_docs_chain_kwargs={"prompt": prompt},
)

translator = Translator()


class ChatRequest(BaseModel):
    question: str
    chat_history: str


def is_valid_question(question: str) -> bool:
    """
    Validate the input question to ensure it is meaningful and related to Indian law or crime.
    """
    question = question.strip()

    if len(question) < 3:
        return False

    # Reject questions that are only digits or only punctuation/symbols.
    if re.match(r'^\d+$', question) or re.match(r'^[^a-zA-Z0-9\s]+$', question):
        return False

    legal_keywords = [
        "IPC", "CrPC", "section", "law", "crime", "penalty", "punishment",
        "legal", "court", "justice", "offense", "fraud", "murder", "theft",
        "bail", "arrest", "FIR", "judgment", "act", "contract", "constitutional",
        "habeas corpus", "petition", "rights", "lawyer", "advocate", "accused",
        "penal", "conviction", "sentence", "appeal", "trial", "witness"
    ]

    # Require at least one legal keyword (case-insensitive match).
    if not any(keyword.lower() in question.lower() for keyword in legal_keywords):
        return False

    return True


@router.post("/chat/")
async def chat(request: ChatRequest):
    try:
        # Detect the language of the incoming question and translate it to English
        # before running it through the retrieval chain.
        detected_lang = await translator.detect(request.question)
        detected_language = detected_lang.lang

        question_translation = await translator.translate(request.question, src=detected_language, dest="en")
        question_in_english = question_translation.text

        if not is_valid_question(question_in_english):
            return {
                "answer": "Please provide a valid legal question related to Indian laws.",
                "language": LANGUAGES.get(detected_language, "unknown")
            }

        # The chain's memory supplies "chat_history"; the value passed here is kept
        # for compatibility with the request schema but is overridden by the memory.
        inputs = {"question": question_in_english, "chat_history": request.chat_history}

        result = qa_chain.invoke(inputs)

        if "answer" not in result:
            raise ValueError("Missing 'answer' key in the result from qa_chain")

        # Translate the English answer back into the user's original language.
        answer_translation = await translator.translate(result["answer"], src="en", dest=detected_language)
        answer_in_original_language = answer_translation.text

        return {
            "answer": answer_in_original_language,
            "language": LANGUAGES.get(detected_language, "unknown")
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}")


@router.get("/")
async def root():
    return {"message": "LawGPT API is running."}


# Register the router only after all of its routes have been declared.
app.include_router(router, prefix="/lawgpt")
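

# Minimal usage sketch (assumes this module is saved as main.py; adjust the module
# name to match your project layout):
#
#   uvicorn main:app --host 0.0.0.0 --port 8000
#
# Example request against the chat endpoint (the question/chat_history fields match
# the ChatRequest schema above):
#
#   curl -X POST http://localhost:8000/lawgpt/chat/ \
#        -H "Content-Type: application/json" \
#        -d '{"question": "What is the punishment for theft under IPC?", "chat_history": ""}'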