import threading
import time
import traceback

import msvcrt
import requests
from transformers import WhisperProcessor, WhisperForConditionalGeneration

from src.utils.config import settings
from src.utils import (
    VoiceGenerator,
    get_ai_response,
    play_audio_with_interrupt,
    init_vad_pipeline,
    detect_speech_segments,
    record_continuous_audio,
    check_for_speech,
    transcribe_audio,
)
from src.utils.audio_queue import AudioGenerationQueue
from src.utils.llm import parse_stream_chunk
from src.utils.text_chunker import TextChunker

settings.setup_directories()

# Per-turn latency checkpoints, reset at the start of each user input.
timing_info = {
    "vad_start": None,
    "transcription_start": None,
    "llm_first_token": None,
    "audio_queued": None,
    "first_audio_play": None,
    "playback_start": None,
    "end": None,
    "transcription_duration": None,
}


def process_input(
    session: requests.Session,
    user_input: str,
    messages: list,
    generator: VoiceGenerator,
    speed: float,
) -> tuple[bool, object]:
    """Processes user input, generates a response, and handles audio output.

    Args:
        session (requests.Session): The requests session to use.
        user_input (str): The user's input text.
        messages (list): The list of messages to send to the LLM.
        generator (VoiceGenerator): The voice generator object.
        speed (float): The playback speed.

    Returns:
        tuple[bool, object]: Whether playback was interrupted, and the
            interrupting audio data (None if no interruption occurred).
    """
    global timing_info
    # Reset the per-turn checkpoints, but keep the transcription timings
    # that main() recorded just before calling us.
    preserved = {
        k: timing_info[k] for k in ("transcription_start", "transcription_duration")
    }
    timing_info = {k: None for k in timing_info}
    timing_info.update(preserved)
    timing_info["vad_start"] = time.perf_counter()

    messages.append({"role": "user", "content": user_input})
    print("\nThinking...")

    try:
        response_stream = get_ai_response(
            session=session,
            messages=messages,
            llm_model=settings.LLM_MODEL,
            llm_url=settings.OLLAMA_URL,
            max_tokens=settings.MAX_TOKENS,
            stream=True,
        )
        if not response_stream:
            print("Failed to get AI response stream.")
            return False, None

        audio_queue = AudioGenerationQueue(generator, speed)
        audio_queue.start()
        chunker = TextChunker()
        complete_response = []

        # Capture the playback worker's result so an interruption can be
        # propagated back to the caller.
        playback_result = {}

        def playback_wrapper():
            timing_info["playback_start"] = time.perf_counter()
            interrupted, interrupt_audio = audio_playback_worker(audio_queue)
            playback_result["interrupted"] = interrupted
            playback_result["audio"] = interrupt_audio

        playback_thread = threading.Thread(target=playback_wrapper, daemon=True)
        playback_thread.start()

        for chunk in response_stream:
            data = parse_stream_chunk(chunk)
            if not data or "choices" not in data:
                continue

            choice = data["choices"][0]
            if "delta" in choice and "content" in choice["delta"]:
                content = choice["delta"]["content"]
                if content:
                    if not timing_info["llm_first_token"]:
                        timing_info["llm_first_token"] = time.perf_counter()
                    print(content, end="", flush=True)
                    chunker.current_text.append(content)
                    text = "".join(chunker.current_text)

                    # Hand complete chunks to the TTS queue as soon as the
                    # chunker deems them speakable; keep the remainder.
                    if chunker.should_process(text):
                        if not timing_info["audio_queued"]:
                            timing_info["audio_queued"] = time.perf_counter()
                        remaining = chunker.process(text, audio_queue)
                        chunker.current_text = [remaining]
                        complete_response.append(text[: len(text) - len(remaining)])

            if choice.get("finish_reason") == "stop":
                final_text = "".join(chunker.current_text).strip()
                if final_text:
                    chunker.process(final_text, audio_queue)
                    complete_response.append(final_text)
                break

        messages.append({"role": "assistant", "content": " ".join(complete_response)})
        print()

        time.sleep(0.1)
        audio_queue.stop()
        playback_thread.join()

        timing_info["end"] = time.perf_counter()
        print_timing_chart(timing_info)
        return playback_result.get("interrupted", False), playback_result.get("audio")
    except Exception as e:
        print(f"\nError during streaming: {str(e)}")
        if "audio_queue" in locals():
            audio_queue.stop()
        return False, None


def audio_playback_worker(audio_queue) -> tuple[bool, object]:
    """Manages audio playback in a separate thread, handling interruptions.

    Args:
        audio_queue (AudioGenerationQueue): The audio queue object.

    Returns:
        tuple[bool, object]: Whether playback was interrupted, and the
            interrupting audio data (None if no interruption occurred).
    """
    global timing_info
    was_interrupted = False
    interrupt_audio = None

    try:
        while True:
            # Check for user speech first so an interruption wins over playback.
            speech_detected, audio_data = check_for_speech()
            if speech_detected:
                was_interrupted = True
                interrupt_audio = audio_data
                break

            audio_data, _ = audio_queue.get_next_audio()
            if audio_data is not None:
                if not timing_info["first_audio_play"]:
                    timing_info["first_audio_play"] = time.perf_counter()
                was_interrupted, interrupt_data = play_audio_with_interrupt(audio_data)
                if was_interrupted:
                    interrupt_audio = interrupt_data
                    break
            else:
                time.sleep(settings.PLAYBACK_DELAY)

            # Exit once generation has stopped and both queues are drained.
            if (
                not audio_queue.is_running
                and audio_queue.sentence_queue.empty()
                and audio_queue.audio_queue.empty()
            ):
                break
    except Exception as e:
        print(f"Error in audio playback: {str(e)}")

    return was_interrupted, interrupt_audio


def main():
    """Main function to run the voice chat bot."""
    with requests.Session() as session:
        try:
            generator = VoiceGenerator(settings.MODELS_DIR, settings.VOICES_DIR)
            messages = [{"role": "system", "content": settings.DEFAULT_SYSTEM_PROMPT}]

            print("\nInitializing Whisper model...")
            whisper_processor = WhisperProcessor.from_pretrained(settings.WHISPER_MODEL)
            whisper_model = WhisperForConditionalGeneration.from_pretrained(
                settings.WHISPER_MODEL
            )

            print("\nInitializing Voice Activity Detection...")
            vad_pipeline = init_vad_pipeline(settings.HUGGINGFACE_TOKEN)

            print("\n=== Voice Chat Bot Initializing ===")
            print("Device being used:", generator.device)
            print("\nInitializing voice generator...")
            result = generator.initialize(settings.TTS_MODEL, settings.VOICE_NAME)
            print(result)
            speed = settings.SPEED

            try:
                print("\nWarming up the LLM model...")
                health = session.get("http://localhost:11434", timeout=3)
                if health.status_code != 200:
                    print("Ollama not running! Start it first.")
                    return

                response_stream = get_ai_response(
                    session=session,
                    messages=[
                        {"role": "system", "content": settings.DEFAULT_SYSTEM_PROMPT},
                        {"role": "user", "content": "Hi!"},
                    ],
                    llm_model=settings.LLM_MODEL,
                    llm_url=settings.OLLAMA_URL,
                    max_tokens=settings.MAX_TOKENS,
                    stream=False,
                )
                if not response_stream:
                    print("Failed to initialize the AI model!")
                    return
            except requests.RequestException as e:
                print(f"Warmup failed: {str(e)}")

            print("\n\n=== Voice Chat Bot Ready ===")
            print("The bot is now listening for speech.")
            print("Just start speaking, and I'll respond automatically!")
            print("You can interrupt me anytime by starting to speak.")

            while True:
                try:
                    # Typed input takes priority over the microphone.
                    if msvcrt.kbhit():
                        user_input = input("\nYou (text): ").strip()
                        if user_input.lower() == "quit":
                            print("Goodbye!")
                            break
                        if user_input:
                            process_input(session, user_input, messages, generator, speed)
                            continue

                    audio_data = record_continuous_audio()
                    if audio_data is not None:
                        speech_segments = detect_speech_segments(vad_pipeline, audio_data)
                        if speech_segments is not None:
                            print("\nTranscribing detected speech...")
                            timing_info["transcription_start"] = time.perf_counter()
                            user_input = transcribe_audio(
                                whisper_processor, whisper_model, speech_segments
                            )
                            timing_info["transcription_duration"] = (
                                time.perf_counter() - timing_info["transcription_start"]
                            )
                            if user_input.strip():
                                print(f"You (voice): {user_input}")
                                was_interrupted, speech_data = process_input(
                                    session, user_input, messages, generator, speed
                                )

                                # If the user spoke over the reply, transcribe the
                                # interrupting speech and answer it immediately.
                                if was_interrupted and speech_data is not None:
                                    speech_segments = detect_speech_segments(
                                        vad_pipeline, speech_data
                                    )
                                    if speech_segments is not None:
                                        print("\nTranscribing interrupted speech...")
                                        user_input = transcribe_audio(
                                            whisper_processor,
                                            whisper_model,
                                            speech_segments,
                                        )
                                        if user_input.strip():
                                            print(f"You (voice): {user_input}")
                                            process_input(
                                                session,
                                                user_input,
                                                messages,
                                                generator,
                                                speed,
                                            )
                            else:
                                print("No clear speech detected, please try again.")

                    # Keep the HTTP connection warm between turns.
                    session.headers.update({"Connection": "keep-alive"})
                except KeyboardInterrupt:
                    print("\nStopping...")
                    break
                except Exception as e:
                    print(f"Error: {str(e)}")
                    continue
        except Exception as e:
            print(f"Error: {str(e)}")
            print("\nFull traceback:")
            traceback.print_exc()


def print_timing_chart(metrics):
    """Prints a timing chart of the per-turn latency checkpoints."""
    events = [
        ("Transcription started", metrics["transcription_start"]),
        ("User stopped speaking", metrics["vad_start"]),
        ("VAD started", metrics["vad_start"]),
        ("LLM first token", metrics["llm_first_token"]),
        ("Audio queued", metrics["audio_queued"]),
        ("First audio played", metrics["first_audio_play"]),
        ("Playback started", metrics["playback_start"]),
        ("End-to-end response", metrics["end"]),
    ]
    # Skip checkpoints that were never hit and sort the rest chronologically
    # so the deltas stay non-negative.
    events = sorted(
        ((name, t) for name, t in events if t is not None), key=lambda e: e[1]
    )
    if not events:
        return
    base_time = events[0][1]

    print("\nTiming Chart:")
    print(f"{'Event':<25} | {'Time (s)':>9} | {'Δ+':>6}")
    print("-" * 45)

    prev_time = base_time
    for name, t in events:
        print(f"{name:<25} | {t - base_time:9.2f} | {t - prev_time:6.2f}")
        prev_time = t


if __name__ == "__main__":
    main()