|
import logging
from typing import Type, Optional
from typing import TypedDict

from langchain_groq import ChatGroq
from langgraph.graph import StateGraph, START, END
from pydantic import BaseModel

# NOTE(review): `constants` is not referenced in this file; presumably it is
# imported for its side effects (e.g. configuration) — confirm before removing.
import constants
|
|
|
|
|
|
|
class State(TypedDict):
    """Data carried through the structured-extraction graph."""

    # Input messages; the "extract" node reads the content of the last one.
    messages: list
    # Result written by the "extract" node: a schema instance, or None when
    # the extraction failed.
    output: Optional[BaseModel]
|
|
|
|
|
|
|
class StructuredOutputExtractor:
    """
    Extract an instance of a Pydantic schema from a free-form text query.

    The query is run through a one-node LangGraph graph
    (START -> "extract" -> END). The node calls a Groq chat model that has
    been bound to the schema with ``with_structured_output``.
    """

    # Model used when the caller does not pass ``model`` explicitly.
    DEFAULT_MODEL = "deepseek-r1-distill-llama-70b"

    _logger = logging.getLogger(__name__)

    def __init__(
        self,
        response_schema: Type[BaseModel],
        model: str = DEFAULT_MODEL,
    ):
        """
        Initializes the extractor for any given structured output model.

        :param response_schema: Pydantic model class used for structured output extraction
        :param model: Groq model identifier handed to ``ChatGroq``
        """
        self.response_schema = response_schema
        self.llm = ChatGroq(model=model)
        self.structured_llm = self.llm.with_structured_output(response_schema)
        self._build_graph()

    def _build_graph(self):
        """
        Build the LangGraph computational graph for structured extraction.

        The compiled graph is stored on ``self.graph``.
        """
        graph_builder = StateGraph(State)

        # Single-step pipeline: START -> extract -> END.
        graph_builder.add_node("extract", self._extract_structured_info)
        graph_builder.add_edge(START, "extract")
        graph_builder.add_edge("extract", END)

        self.graph = graph_builder.compile()

    def _extract_structured_info(self, state: State) -> dict:
        """
        Extract structured information using the specified response model.

        Failures are logged and reported as ``None`` instead of being raised,
        so a graph run always finishes with an ``output`` entry.

        :param state: Current graph state
        :return: State update of the form ``{"output": <model instance or None>}``
        """
        messages = state.get("messages") or []
        if not messages:
            # Without a message there is no query to send to the model.
            self._logger.error("Error during extraction: state contains no messages")
            return {"output": None}

        query = messages[-1].content
        self._logger.debug("Processing query: %s", query)

        # Only the model call sits in the try body; it is the step expected to fail.
        try:
            output = self.structured_llm.invoke(query)
        except Exception:
            # Best-effort contract: keep the traceback in the log, signal
            # failure to the caller with None.
            self._logger.exception("Error during extraction")
            return {"output": None}
        return {"output": output}

    def extract(self, query: str) -> Optional[BaseModel]:
        """
        Public method to extract structured information.

        :param query: Input query for structured output extraction
        :return: Structured model object or None
        """
        # Function-scope import: this is the only place the message class is used.
        from langchain_core.messages import HumanMessage

        # The query is caller-supplied input, so it is stored as a human turn
        # rather than a system instruction. The graph node only reads the
        # message's ``content``.
        final_state = self.graph.invoke(
            {"messages": [HumanMessage(content=query)]}
        )
        return final_state.get("output")
|
|
|
|
|
if __name__ == "__main__":
    # Demo: ask the model about a film and print the parsed result.

    class Movie(BaseModel):
        """Example schema describing the fields requested from the model."""

        title: str
        year: int
        genre: str
        rating: Optional[float] = None
        actors: list[str] = []

    prompt = "Tell me about the movie Inception. Provide details about its title, year, genre, rating, and main actors."
    movie = StructuredOutputExtractor(response_schema=Movie).extract(prompt)

    print(type(movie))
    if movie:
        print(movie)