import gradio as gr from pydantic import BaseModel, Field from story_beam_search.stories_generator import StoryGenerationSystem genre_choices = [ "children", "mystery", "adventure", "sci-fi", "fantasy", "romance", "comedy", "drama", "horror", ] class InputModel(BaseModel): prompt: str genre: str num_stories: int = Field(3, ge=2, le=7) temperature: float = Field(2.5, ge=0.7, le=3.5) max_length: int = Field(60, ge=30, le=200) def create_story_generation_interface() -> gr.Interface: # Initialize the story generation system system = StoryGenerationSystem() system.initialize() def generate_stories( prompt: str, genre: str, num_stories: int, temperature: float, max_length: int ) -> tuple[str, list[str]]: """ Generate and evaluate stories based on user input. Returns a tuple of (detailed_scores, story_texts). """ # Validate inputs.Gradio seems to validate chioces but not the range of the values input_values = InputModel( prompt=prompt, genre=genre, num_stories=num_stories, temperature=temperature, max_length=max_length, ) # Update beam search config with user parameters system.beam_search.config.temperature = input_values.temperature system.beam_search.config.max_length = input_values.max_length # Generate and evaluate stories ranked_stories = system.generate_and_evaluate( input_values.prompt, input_values.genre, num_stories=input_values.num_stories, ) # Format detailed scores detailed_scores = "" story_texts = [] for i, (story, scores) in enumerate(ranked_stories, 1): detailed_scores += f"Story {i}:\n" detailed_scores += f"Total Score: {scores.total:.3f}\n" detailed_scores += f"Coherence: {scores.coherence:.3f}\n" detailed_scores += f"Fluency: {scores.fluency:.3f}\n" detailed_scores += f"Genre Alignment: {scores.genre_alignment:.3f}\n" detailed_scores += "-" * 50 + "\n" story_texts.append(f"Story {i}:\n{story}\n") return detailed_scores, "\n".join(story_texts) # Define interface components prompt_input = gr.Textbox( label="Story Prompt", placeholder="Enter the beginning of your story...", lines=3, ) genre_input = gr.Dropdown( choices=genre_choices, label="Genre", value="fantasy", ) num_stories_input = gr.Slider( minimum=2, maximum=7, value=3, step=1, label="Number of Stories to Generate" ) temperature_input = gr.Slider( minimum=0.7, maximum=3.5, value=2.5, step=0.1, label="Temperature (Creativity)" ) max_length_input = gr.Slider( minimum=40, maximum=200, value=60, step=20, label="Maximum Length" ) # Output components scores_output = gr.Textbox(label="Detailed Scores", lines=10, interactive=False) stories_output = gr.Textbox(label="Generated Stories", lines=15, interactive=False) # Create the interface interface = gr.Interface( fn=generate_stories, inputs=[ prompt_input, genre_input, num_stories_input, temperature_input, max_length_input, ], outputs=[scores_output, stories_output], title="AI Story Generator", description=""" Generate creative stories using AI! Enter a prompt and choose your preferences. The system will generate multiple stories and evaluate them based on coherence, fluency, and genre alignment. """, examples=[ [ "Once upon a time in a magical forest, the trees whispered secrets, and moonlight revealed hidden paths to a realm where time stood still.", "fantasy", 3, 1.8, 150, ], [ "The detective knelt beside the bloodstained carpet, her gaze sharp as she traced the faint outline of a shoeprint.", "mystery", 3, 2.7, 200, ], ], theme=gr.themes.Soft(), ) return interface if __name__ == "__main__": # Create and launch the interface interface = create_story_generation_interface() interface.launch(show_error=True)