# Kokoro-Conversational/speech_to_speech.py
# Author: Abdullah Al Asif
import msvcrt  # Windows-only stdlib module: non-blocking keyboard polling
import threading
import time
import traceback

import requests
from transformers import WhisperProcessor, WhisperForConditionalGeneration

from src.utils.config import settings
from src.utils import (
    VoiceGenerator,
    get_ai_response,
    play_audio_with_interrupt,
    init_vad_pipeline,
    detect_speech_segments,
    record_continuous_audio,
    check_for_speech,
    transcribe_audio,
)
from src.utils.audio_queue import AudioGenerationQueue
from src.utils.llm import parse_stream_chunk
from src.utils.text_chunker import TextChunker
settings.setup_directories()
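
# Per-turn latency markers, reset at the start of each user turn and rendered
# by print_timing_chart() once the response finishes.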
timing_info = {
"vad_start": None,
"transcription_start": None,
"llm_first_token": None,
"audio_queued": None,
"first_audio_play": None,
"playback_start": None,
"end": None,
"transcription_duration": None,
}
def process_input(
session: requests.Session,
user_input: str,
messages: list,
generator: VoiceGenerator,
speed: float,
) -> tuple[bool, object]:
"""Processes user input, generates a response, and handles audio output.
Args:
session (requests.Session): The requests session to use.
user_input (str): The user's input text.
messages (list): The list of messages to send to the LLM.
generator (VoiceGenerator): The voice generator object.
speed (float): The playback speed.
Returns:
        tuple[bool, object]: A (was_interrupted, interrupt_audio) pair; interrupt_audio holds the speech captured during a barge-in, or None.
"""
global timing_info
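    # Reset the per-turn markers; vad_start anchors every later delta.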
timing_info = {k: None for k in timing_info}
timing_info["vad_start"] = time.perf_counter()
messages.append({"role": "user", "content": user_input})
print("\nThinking...")
try:
response_stream = get_ai_response(
session=session,
messages=messages,
llm_model=settings.LLM_MODEL,
llm_url=settings.OLLAMA_URL,
max_tokens=settings.MAX_TOKENS,
stream=True,
)
if not response_stream:
print("Failed to get AI response stream.")
return False, None
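
        # Start the TTS generation queue and a background playback thread so
        # synthesis and playback overlap with LLM token streaming.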
audio_queue = AudioGenerationQueue(generator, speed)
audio_queue.start()
chunker = TextChunker()
complete_response = []
        playback_result = {}

        def playback_wrapper():
            # Record when playback begins and capture the worker's result so
            # the caller can see whether the user barged in.
            timing_info["playback_start"] = time.perf_counter()
            playback_result["value"] = audio_playback_worker(audio_queue)

        playback_thread = threading.Thread(target=playback_wrapper, daemon=True)
        playback_thread.start()
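
        # Consume the streamed completion: print tokens as they arrive and
        # hand sentence-sized chunks to the TTS queue as soon as they form.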
for chunk in response_stream:
data = parse_stream_chunk(chunk)
if not data or "choices" not in data:
continue
choice = data["choices"][0]
if "delta" in choice and "content" in choice["delta"]:
content = choice["delta"]["content"]
if content:
if not timing_info["llm_first_token"]:
timing_info["llm_first_token"] = time.perf_counter()
print(content, end="", flush=True)
chunker.current_text.append(content)
text = "".join(chunker.current_text)
if chunker.should_process(text):
if not timing_info["audio_queued"]:
timing_info["audio_queued"] = time.perf_counter()
remaining = chunker.process(text, audio_queue)
chunker.current_text = [remaining]
complete_response.append(text[: len(text) - len(remaining)])
if choice.get("finish_reason") == "stop":
final_text = "".join(chunker.current_text).strip()
if final_text:
chunker.process(final_text, audio_queue)
complete_response.append(final_text)
break
        messages.append({"role": "assistant", "content": " ".join(complete_response)})
        print()

        # Give in-flight synthesis a moment to finish, then drain playback
        # before reporting timings.
        time.sleep(0.1)
        audio_queue.stop()
        playback_thread.join()

        timing_info["end"] = time.perf_counter()
        print_timing_chart(timing_info)

        was_interrupted, interrupt_audio = playback_result.get("value", (False, None))
        return was_interrupted, interrupt_audio
except Exception as e:
print(f"\nError during streaming: {str(e)}")
if "audio_queue" in locals():
audio_queue.stop()
return False, None
def audio_playback_worker(audio_queue) -> tuple[bool, None]:
"""Manages audio playback in a separate thread, handling interruptions.
Args:
audio_queue (AudioGenerationQueue): The audio queue object.
Returns:
tuple[bool, None]: A tuple containing a boolean indicating if the playback was interrupted and the interrupt audio data.
"""
global timing_info
was_interrupted = False
interrupt_audio = None
try:
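        # Poll the microphone between clips: detected speech is a barge-in,
        # so stop playback and return the captured audio for transcription.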
while True:
speech_detected, audio_data = check_for_speech()
if speech_detected:
was_interrupted = True
interrupt_audio = audio_data
break
audio_data, _ = audio_queue.get_next_audio()
if audio_data is not None:
if not timing_info["first_audio_play"]:
timing_info["first_audio_play"] = time.perf_counter()
was_interrupted, interrupt_data = play_audio_with_interrupt(audio_data)
if was_interrupted:
interrupt_audio = interrupt_data
break
else:
time.sleep(settings.PLAYBACK_DELAY)
if (
not audio_queue.is_running
and audio_queue.sentence_queue.empty()
and audio_queue.audio_queue.empty()
):
break
except Exception as e:
print(f"Error in audio playback: {str(e)}")
return was_interrupted, interrupt_audio
def main():
"""Main function to run the voice chat bot."""
    with requests.Session() as session:
        try:
generator = VoiceGenerator(settings.MODELS_DIR, settings.VOICES_DIR)
messages = [{"role": "system", "content": settings.DEFAULT_SYSTEM_PROMPT}]
print("\nInitializing Whisper model...")
whisper_processor = WhisperProcessor.from_pretrained(settings.WHISPER_MODEL)
whisper_model = WhisperForConditionalGeneration.from_pretrained(
settings.WHISPER_MODEL
)
print("\nInitializing Voice Activity Detection...")
vad_pipeline = init_vad_pipeline(settings.HUGGINGFACE_TOKEN)
print("\n=== Voice Chat Bot Initializing ===")
print("Device being used:", generator.device)
print("\nInitializing voice generator...")
result = generator.initialize(settings.TTS_MODEL, settings.VOICE_NAME)
print(result)
speed = settings.SPEED
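
            # Warm-up: confirm the Ollama server is reachable, then send one
            # short non-streaming request so the model is loaded before the
            # first real turn.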
try:
print("\nWarming up the LLM model...")
health = session.get("http://localhost:11434", timeout=3)
if health.status_code != 200:
print("Ollama not running! Start it first.")
return
response_stream = get_ai_response(
session=session,
messages=[
{"role": "system", "content": settings.DEFAULT_SYSTEM_PROMPT},
{"role": "user", "content": "Hi!"},
],
llm_model=settings.LLM_MODEL,
llm_url=settings.OLLAMA_URL,
max_tokens=settings.MAX_TOKENS,
stream=False,
)
if not response_stream:
print("Failed to initialized the AI model!")
return
except requests.RequestException as e:
print(f"Warmup failed: {str(e)}")
print("\n\n=== Voice Chat Bot Ready ===")
print("The bot is now listening for speech.")
print("Just start speaking, and I'll respond automatically!")
print("You can interrupt me anytime by starting to speak.")
while True:
try:
                    if msvcrt.kbhit():
                        user_input = input("\nYou (text): ").strip()
                        if user_input.lower() == "quit":
                            print("Goodbye!")
                            break
                        if user_input:
                            # Typed text follows the same pipeline as speech.
                            process_input(session, user_input, messages, generator, speed)
                            continue
audio_data = record_continuous_audio()
if audio_data is not None:
speech_segments = detect_speech_segments(
vad_pipeline, audio_data
)
if speech_segments is not None:
print("\nTranscribing detected speech...")
timing_info["transcription_start"] = time.perf_counter()
user_input = transcribe_audio(
whisper_processor, whisper_model, speech_segments
)
timing_info["transcription_duration"] = (
time.perf_counter() - timing_info["transcription_start"]
)
if user_input.strip():
print(f"You (voice): {user_input}")
was_interrupted, speech_data = process_input(
session, user_input, messages, generator, speed
)
if was_interrupted and speech_data is not None:
speech_segments = detect_speech_segments(
vad_pipeline, speech_data
)
if speech_segments is not None:
print("\nTranscribing interrupted speech...")
user_input = transcribe_audio(
whisper_processor,
whisper_model,
speech_segments,
)
if user_input.strip():
print(f"You (voice): {user_input}")
process_input(
session,
user_input,
messages,
generator,
speed,
)
else:
print("No clear speech detected, please try again.")
                    # requests.Session already reuses keep-alive connections,
                    # so no per-iteration connection cleanup is needed.
except KeyboardInterrupt:
print("\nStopping...")
break
except Exception as e:
print(f"Error: {str(e)}")
continue
except Exception as e:
print(f"Error: {str(e)}")
print("\nFull traceback:")
traceback.print_exc()
def print_timing_chart(metrics):
"""Prints timing chart from global metrics"""
base_time = metrics["vad_start"]
events = [
("User stopped speaking", metrics["vad_start"]),
("VAD started", metrics["vad_start"]),
("Transcription started", metrics["transcription_start"]),
("LLM first token", metrics["llm_first_token"]),
("Audio queued", metrics["audio_queued"]),
("First audio played", metrics["first_audio_play"]),
("Playback started", metrics["playback_start"]),
("End-to-end response", metrics["end"]),
]
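    # "Time (s)" is elapsed since the user stopped speaking; "Δ+" is the gap
    # from the previous recorded event.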
print("\nTiming Chart:")
print(f"{'Event':<25} | {'Time (s)':>9} | {'Δ+':>6}")
print("-" * 45)
prev_time = base_time
for name, t in events:
if t is None:
continue
elapsed = t - base_time
delta = t - prev_time
print(f"{name:<25} | {elapsed:9.2f} | {delta:6.2f}")
prev_time = t
if __name__ == "__main__":
main()