from fastapi import FastAPI, APIRouter, HTTPException
from pydantic import BaseModel
from pathlib import Path
import os
import re
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.prompts import PromptTemplate
from langchain_together import Together
from langchain.memory import ConversationBufferWindowMemory
from langchain.chains import ConversationalRetrievalChain
from googletrans import Translator, LANGUAGES

# Set the API key for Together.ai (export TOGETHER_AI_API in the environment;
# the placeholder below is only a fallback and should never hold a real secret)
os.environ.setdefault("TOGETHER_AI_API", "YOUR_TOGETHER_API_KEY")

# Ensure a writable cache directory is available for model downloads
# (TRANSFORMERS_CACHE is deprecated in recent transformers releases; HF_HOME supersedes it)
os.environ['TRANSFORMERS_CACHE'] = '/tmp/cache'

# Initialize FastAPI app
app = FastAPI()

# Initialize the FastAPI router; it is included into the app at the bottom of
# this file, after all routes are registered (including it before the route
# definitions would register no paths)
router = APIRouter()

bot_name = "LawGPT"

# Load the embedding model at startup, pinned to a specific revision for
# reproducibility (note: this load is eager, so the import pays the cost once)
embeddings = HuggingFaceEmbeddings(
    model_name="nomic-ai/nomic-embed-text-v1",
    model_kwargs={"trust_remote_code": True, "revision": "289f532e14dbbbd5a04753fa58739e9ba766f3c7"},
)

index_path = Path("models/index.faiss")
if not index_path.exists():
    raise FileNotFoundError("FAISS index not found. Please generate it and place it in 'models'.")
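
# A minimal sketch of how the index might be generated offline, assuming a
# hypothetical data/ directory of .txt source documents (not part of this app):
#
#   from langchain_community.document_loaders import DirectoryLoader, TextLoader
#   from langchain.text_splitter import RecursiveCharacterTextSplitter
#   docs = DirectoryLoader("data", glob="**/*.txt", loader_cls=TextLoader).load()
#   chunks = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200).split_documents(docs)
#   FAISS.from_documents(chunks, embeddings).save_local("models")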

# Load the FAISS index
db = FAISS.load_local("models", embeddings, allow_dangerous_deserialization=True)
db_retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 4})

# Define the prompt template for the legal chatbot (Mistral instruct format:
# the instruction block closes with [/INST], not a second [INST])
prompt_template = """<s>[INST] You are a legal chatbot specializing in Indian Penal Code queries; your objective is to provide accurate and concise information.
CONTEXT: {context}
CHAT HISTORY: {chat_history}
QUESTION: {question}
ANSWER: [/INST]"""
prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question", "chat_history"])

# Set up the LLM (Large Language Model) for the chatbot
llm = Together(
    model="mistralai/Mistral-7B-Instruct-v0.2",
    temperature=0.5,
    max_tokens=1024,
    together_api_key=os.getenv("TOGETHER_AI_API"),
)

# Windowed conversational memory: only the last k=2 exchanges are kept in chat_history
memory = ConversationBufferWindowMemory(k=2, memory_key="chat_history", return_messages=True)

# Create the conversational retrieval chain with the LLM and retriever
qa_chain = ConversationalRetrievalChain.from_llm(
    llm=llm,
    memory=memory,
    retriever=db_retriever,
    combine_docs_chain_kwargs={"prompt": prompt},
)
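
# Quick local sanity check, bypassing the HTTP layer (assumes the FAISS index
# and API key above are in place; illustrative only):
#
#   result = qa_chain.invoke({"question": "What does Section 420 IPC cover?"})
#   print(result["answer"])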

# Translator instance (recent googletrans releases make detect/translate async,
# which matches the `await` calls in the /chat endpoint below)
translator = Translator()

# Input schema for chat requests
class ChatRequest(BaseModel):
    question: str
    chat_history: str

# Function to validate the input question
def is_valid_question(question: str) -> bool:
    """
    Validate the input question to ensure it is meaningful and related to Indian law or crime.
    """
    question = question.strip()

    # Reject if the question is too short
    if len(question) < 3:
        return False

    # Reject if the question contains only numbers or symbols
    if re.match(r'^\d+$', question) or re.match(r'^[^a-zA-Z0-9\s]+$', question):
        return False

    # Define keywords related to Indian law and crime
    legal_keywords = [
        "IPC", "CrPC", "section", "law", "crime", "penalty", "punishment",
        "legal", "court", "justice", "offense", "fraud", "murder", "theft",
        "bail", "arrest", "FIR", "judgment", "act", "contract", "constitutional",
        "habeas corpus", "petition", "rights", "lawyer", "advocate", "accused",
        "penal", "conviction", "sentence", "appeal", "trial", "witness"
    ]

    # Check if the question contains at least one legal keyword
    if not any(keyword.lower() in question.lower() for keyword in legal_keywords):
        return False

    return True
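
# Illustrative behaviour of the validator above:
#   is_valid_question("What is Section 302 of the IPC?")  # True  (contains "section"/"IPC")
#   is_valid_question("12345")                            # False (digits only)
#   is_valid_question("hello there")                      # False (no legal keyword)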

# POST endpoint to handle chat requests
@router.post("/chat/")
async def chat(request: ChatRequest):
    try:
        # Detect language
        detected_lang = await translator.detect(request.question)
        detected_language = detected_lang.lang

        # Translate question to English
        question_translation = await translator.translate(request.question, src=detected_language, dest="en")
        question_in_english = question_translation.text

        # Validate translated question
        if not is_valid_question(question_in_english):
            return {
                "answer": "Please provide a valid legal question related to Indian laws.",
                "language": LANGUAGES.get(detected_language, "unknown")
            }

        # Prepare input for the chain (with memory attached, chat_history is
        # loaded from memory and overrides the client-supplied value)
        inputs = {"question": question_in_english, "chat_history": request.chat_history}

        # Run the conversational retrieval chain (invoke() replaces the deprecated direct call)
        result = qa_chain.invoke(inputs)

        # Ensure response contains an answer
        if 'answer' not in result:
            raise ValueError("Missing 'answer' key in the result from qa_chain")

        # Translate response back to original language
        answer_translation = await translator.translate(result["answer"], src="en", dest=detected_language)
        answer_in_original_language = answer_translation.text

        return {
            "answer": answer_in_original_language,
            "language": LANGUAGES.get(detected_language, "unknown")
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}")
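
# Example request, assuming the app is served locally on port 8000:
#   curl -X POST http://localhost:8000/lawgpt/chat/ \
#        -H "Content-Type: application/json" \
#        -d '{"question": "What is Section 302 of the IPC?", "chat_history": ""}'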

# GET endpoint to check if the API is running
@router.get("/")
async def root():
    return {"message": "LawGPT API is running."}

# Include the router last, after every endpoint has been defined, so FastAPI
# registers all of the paths under the /lawgpt prefix.
app.include_router(router, prefix="/lawgpt")
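
# Minimal sketch of a local entry point, assuming this file is saved as main.py
# and uvicorn is installed (in production, `uvicorn main:app` is more typical):
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)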