ruslanmv committed on
Commit b3fce51 · verified · 1 Parent(s): cdbc5ce

Create enhance.py

Files changed (1)
  1. enhance.py +82 -0
enhance.py ADDED
@@ -0,0 +1,82 @@
+import time
+import requests
+import json
+
+def generate(message, max_new_tokens=256, temperature=0.9, top_p=0.95, repetition_penalty=1.0):
+    """
+    Generates an enhanced prompt using the streaming inference mechanism of a Hugging Face API endpoint.
+    This function formats the prompt with a system instruction, sends a streaming request to the API,
+    and yields the accumulated text as tokens are received.
+
+    Parameters:
+        message (str): The user's input prompt.
+        max_new_tokens (int): The maximum number of tokens to generate.
+        temperature (float): Sampling temperature.
+        top_p (float): Nucleus sampling parameter.
+        repetition_penalty (float): Penalty factor for repetition (not used in the payload but kept for API consistency).
+
+    Yields:
+        str: The accumulated generated text as it streams in.
+    """
+    # Define the system prompt.
+    SYSTEM_PROMPT = (
+        "You are a prompt enhancer and your work is to enhance the given prompt under 100 words "
+        "without changing the essence, only write the enhanced prompt and nothing else."
+    )
+    # Format the prompt with a timestamp for uniqueness.
+    timestamp = time.time()
+    formatted_prompt = (
+        f"<s>[INST] SYSTEM: {SYSTEM_PROMPT} [/INST]"
+        f"[INST] {message} {timestamp} [/INST]"
+    )
+
+    # Define the API endpoint and headers.
+    api_url = "https://ruslanmv-hf-llm-api.hf.space/api/v1/chat/completions"
+    headers = {"Content-Type": "application/json"}
+
+    # Build the payload for the inference request.
+    payload = {
+        "model": "mixtral-8x7b",
+        "messages": [{"role": "user", "content": formatted_prompt}],
+        "temperature": temperature,
+        "top_p": top_p,
+        "max_tokens": max_new_tokens,
+        "use_cache": False,
+        "stream": True
+    }
+
+    try:
+        response = requests.post(api_url, headers=headers, json=payload, stream=True)
+        response.raise_for_status()
+        full_output = ""
+
+        # Process the streaming response line by line.
+        for line in response.iter_lines():
+            if not line:
+                continue
+
+            decoded_line = line.decode("utf-8").strip()
+            # Remove the "data:" prefix if present.
+            if decoded_line.startswith("data:"):
+                decoded_line = decoded_line[len("data:"):].strip()
+
+            # Check if the stream is finished.
+            if decoded_line == "[DONE]":
+                break
+
+            try:
+                json_data = json.loads(decoded_line)
+                for choice in json_data.get("choices", []):
+                    delta = choice.get("delta", {})
+                    content = delta.get("content", "")
+                    full_output += content
+                    yield full_output  # Yield the accumulated text so far.
+
+                    # If the finish reason is provided, stop further streaming.
+                    if choice.get("finish_reason") == "stop":
+                        return
+            except json.JSONDecodeError:
+                # If a line is not valid JSON, skip it.
+                continue
+    except requests.exceptions.RequestException as e:
+        yield f"Error during generation: {str(e)}"
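
Usage note (not part of the commit): `generate` is a generator that yields the full accumulated text on every streamed chunk, so a caller that wants incremental output should print only the unseen suffix of each yield. A minimal consumption sketch, assuming `enhance.py` is importable and the Space endpoint above is live; the prompt string is a made-up example:

    # Minimal sketch: stream an enhanced prompt and print tokens as they arrive.
    from enhance import generate

    printed = ""
    for accumulated in generate("a cat sitting on a windowsill", max_new_tokens=128):
        # Each yield is the complete text so far, so emit only the new part.
        print(accumulated[len(printed):], end="", flush=True)
        printed = accumulated
    print()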