import json
import time

import requests

# Instruction sent to the model ahead of the user's prompt.
_SYSTEM_PROMPT = (
    "You are a prompt enhancer and your work is to enhance the given prompt under 100 words "
    "without changing the essence, only write the enhanced prompt and nothing else."
)
_API_URL = "https://ruslanmv-hf-llm-api.hf.space/api/v1/chat/completions"
_DATA_PREFIX = "data:"
_DONE_MARKER = "[DONE]"


def _delta_text(choice):
    """Return the text fragment carried by one streamed choice, or "" if none."""
    if not isinstance(choice, dict):
        return ""
    delta = choice.get("delta")
    if not isinstance(delta, dict):
        return ""
    content = delta.get("content")
    # Chunks may carry `"content": null` (e.g. role-only or final chunks);
    # treat anything that is not a string as "no new text".
    return content if isinstance(content, str) else ""


def generate(message, max_new_tokens=256, temperature=0.9, top_p=0.95,
             repetition_penalty=1.0, timeout=(10, 120)):
    """
    Stream an enhanced version of a prompt from the chat-completions endpoint.

    The user's message is wrapped together with a fixed system instruction,
    posted as a streaming request, and the accumulated generated text is
    yielded each time a new chunk arrives.

    Parameters:
        message (str): The user's input prompt.
        max_new_tokens (int): Sent as ``max_tokens`` in the request payload.
        temperature (float): Sampling temperature.
        top_p (float): Nucleus sampling parameter.
        repetition_penalty (float): Not sent in the payload; accepted only
            to keep the call signature stable.
        timeout (float | tuple | None): Passed to ``requests.post``. A tuple
            is ``(connect, read)`` seconds; ``None`` disables the timeout.

    Yields:
        str: The accumulated generated text so far. If the HTTP request
        fails (including a timeout or a non-2xx status), a single string of
        the form ``"Error during generation: ..."`` is yielded instead.
    """
    # The current timestamp is appended to the user message so that every
    # request body is unique.
    formatted_prompt = (
        f"[INST] SYSTEM: {_SYSTEM_PROMPT} [/INST]"
        f"[INST] {message} {time.time()} [/INST]"
    )
    headers = {"Content-Type": "application/json"}
    payload = {
        "model": "mixtral-8x7b",
        "messages": [{"role": "user", "content": formatted_prompt}],
        "temperature": temperature,
        "top_p": top_p,
        "max_tokens": max_new_tokens,
        "use_cache": False,
        "stream": True,
    }

    try:
        # The context manager guarantees the streamed connection is released
        # even when we stop reading early or the consumer drops the generator.
        with requests.post(
            _API_URL,
            headers=headers,
            json=payload,
            stream=True,
            timeout=timeout,
        ) as response:
            response.raise_for_status()

            full_output = ""
            for raw_line in response.iter_lines():
                if not raw_line:
                    continue

                data = raw_line.decode("utf-8", errors="replace").strip()
                # Remove the "data:" prefix if present.
                if data.startswith(_DATA_PREFIX):
                    data = data[len(_DATA_PREFIX):].strip()

                # Explicit end-of-stream marker.
                if data == _DONE_MARKER:
                    break

                try:
                    event = json.loads(data)
                except json.JSONDecodeError:
                    # Lines that are not valid JSON are skipped.
                    continue
                if not isinstance(event, dict):
                    # Valid JSON that is not an object carries no choices.
                    continue

                for choice in event.get("choices") or []:
                    full_output += _delta_text(choice)
                    yield full_output  # Accumulated text so far.

                    # Stop as soon as a choice reports completion.
                    if isinstance(choice, dict) and choice.get("finish_reason") == "stop":
                        return
    except requests.exceptions.RequestException as e:
        yield f"Error during generation: {str(e)}"