# audio-to-video-generator / structured_output_extractor.py
from typing import Optional, Type, TypedDict

from pydantic import BaseModel
from langchain_core.messages import SystemMessage
from langchain_groq import ChatGroq
from langgraph.graph import StateGraph, START, END

import constants  # constants.py is assumed to hold LLM provider configuration (e.g. API keys)
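
# Note: `constants` is imported for its side effects only; none of its names are used
# directly in this module. The sketch below is an assumption about what constants.py
# might provide (hypothetical, not the actual file): loading provider credentials such
# as GROQ_API_KEY into the environment so ChatGroq can authenticate.
#
#   # constants.py (assumed, illustrative only)
#   import os
#   from dotenv import load_dotenv
#
#   load_dotenv()  # reads GROQ_API_KEY from a local .env file, if present
#   GROQ_API_KEY = os.environ.get("GROQ_API_KEY")  # picked up automatically by langchain_groq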


# Define the State structure passed between graph nodes
class State(TypedDict):
    messages: list
    output: Optional[BaseModel]


# Generic Pydantic model-based structured output extractor
class StructuredOutputExtractor:
    def __init__(self, response_schema: Type[BaseModel]):
        """
        Initialize the extractor for any given structured output model.

        :param response_schema: Pydantic model class used for structured output extraction
        """
        self.response_schema = response_schema

        # Initialize the language model (provider and API keys come from constants.py)
        # self.llm = ChatGroq(model="llama-3.3-70b-versatile")  # token limit: 100k tokens
        self.llm = ChatGroq(model="deepseek-r1-distill-llama-70b")  # currently no per-day limit

        # Bind the model with structured output capability
        self.structured_llm = self.llm.with_structured_output(response_schema)

        # Build the graph for structured output extraction
        self._build_graph()

    def _build_graph(self):
        """
        Build the LangGraph computational graph for structured extraction.
        """
        graph_builder = StateGraph(State)

        # Add nodes and edges for structured output: START -> extract -> END
        graph_builder.add_node("extract", self._extract_structured_info)
        graph_builder.add_edge(START, "extract")
        graph_builder.add_edge("extract", END)

        self.graph = graph_builder.compile()
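
    # For debugging, the compiled graph can be inspected; LangGraph's compiled graphs
    # expose get_graph(), whose result supports Mermaid rendering. A minimal sketch
    # (illustrative, `SomeSchema` is a placeholder Pydantic model):
    #
    #   extractor = StructuredOutputExtractor(response_schema=SomeSchema)
    #   print(extractor.graph.get_graph().draw_mermaid())  # START -> extract -> END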

    def _extract_structured_info(self, state: dict):
        """
        Extract structured information using the specified response model.

        :param state: Current graph state
        :return: Updated state with structured output
        """
        query = state['messages'][-1].content
        print(f"Processing query: {query}")
        try:
            # Extract details using the structured model
            output = self.structured_llm.invoke(query)
            # Return the structured response
            return {"output": output}
        except Exception as e:
            print(f"Error during extraction: {e}")
            return {"output": None}

    def extract(self, query: str) -> Optional[BaseModel]:
        """
        Public method to extract structured information.

        :param query: Input query for structured output extraction
        :return: Structured model object, or None if extraction failed
        """
        result = self.graph.invoke({
            "messages": [SystemMessage(content=query)]
        })
        # Return the structured model response, if available
        return result.get('output')


if __name__ == '__main__':
    # Example Pydantic model (e.g., Movie)
    class Movie(BaseModel):
        title: str
        year: int
        genre: str
        rating: Optional[float] = None
        actors: list[str] = []

    # Example usage with the generic structured extractor
    extractor = StructuredOutputExtractor(response_schema=Movie)

    query = "Tell me about the movie Inception. Provide details about its title, year, genre, rating, and main actors."
    result = extractor.extract(query)

    print(type(result))
    if result:
        print(result)
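
    # Because the extractor is schema-agnostic, the same class works with any other
    # Pydantic model. A hypothetical reuse sketch (the Song schema is illustrative,
    # not part of this project):
    #
    #   class Song(BaseModel):
    #       title: str
    #       artist: str
    #       year: Optional[int] = None
    #
    #   song_extractor = StructuredOutputExtractor(response_schema=Song)
    #   print(song_extractor.extract("Tell me about the song Bohemian Rhapsody by Queen."))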