# TODO: Implement smarter load balancing when distributing requests across multiple
# backend instances. The current implementation picks a backend at random
# (random.choice), which is random selection rather than true round-robin. Algorithms
# such as weighted round-robin, least connections, or least response time would give
# better resource utilization and performance (see the illustrative sketch after
# get_api_key below).
import asyncio
import json
import logging
import os
import random
import re
import time
from typing import Optional, Union
from urllib.parse import urlparse

import aiohttp
from aiocache import cached
import requests
from fastapi import (
    Depends,
    FastAPI,
    File,
    HTTPException,
    Request,
    UploadFile,
    APIRouter,
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, ConfigDict
from starlette.background import BackgroundTask

from open_webui.models.models import Models
from open_webui.utils.misc import (
    calculate_sha256,
)
from open_webui.utils.payload import (
    apply_model_params_to_body_ollama,
    apply_model_params_to_body_openai,
    apply_model_system_prompt_to_body,
)
from open_webui.utils.auth import get_admin_user, get_verified_user
from open_webui.utils.access_control import has_access

from open_webui.config import (
    UPLOAD_DIR,
)
from open_webui.env import (
    ENV,
    SRC_LOG_LEVELS,
    AIOHTTP_CLIENT_TIMEOUT,
    AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST,
    BYPASS_MODEL_ACCESS_CONTROL,
)
from open_webui.constants import ERROR_MESSAGES


log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OLLAMA"])


##########################################
#
# Utility functions
#
##########################################
async def send_get_request(url, key=None):
    timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
    try:
        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
            async with session.get(
                url, headers={**({"Authorization": f"Bearer {key}"} if key else {})}
            ) as response:
                return await response.json()
    except Exception as e:
        # Treat any failure (connection, timeout, bad JSON) as an unreachable backend
        log.error(f"Connection error: {e}")
        return None


async def cleanup_response(
    response: Optional[aiohttp.ClientResponse],
    session: Optional[aiohttp.ClientSession],
):
    if response:
        response.close()
    if session:
        await session.close()
async def send_post_request(
    url: str,
    payload: Union[str, bytes],
    stream: bool = True,
    key: Optional[str] = None,
    content_type: Optional[str] = None,
):
    r = None
    session = None
    try:
        session = aiohttp.ClientSession(
            trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
        )
        r = await session.post(
            url,
            data=payload,
            headers={
                "Content-Type": "application/json",
                **({"Authorization": f"Bearer {key}"} if key else {}),
            },
        )
        r.raise_for_status()

        if stream:
            response_headers = dict(r.headers)
            if content_type:
                response_headers["Content-Type"] = content_type
            return StreamingResponse(
                r.content,
                status_code=r.status,
                headers=response_headers,
                background=BackgroundTask(
                    cleanup_response, response=r, session=session
                ),
            )
        else:
            res = await r.json()
            await cleanup_response(r, session)
            return res
    except Exception as e:
        detail = None
        if r is not None:
            try:
                res = await r.json()
                if "error" in res:
                    detail = f"Ollama: {res.get('error', 'Unknown error')}"
            except Exception:
                detail = f"Ollama: {e}"

        # On failure, close the response and session here; on streaming success
        # this cleanup is deferred to the BackgroundTask above instead.
        await cleanup_response(r, session)

        raise HTTPException(
            status_code=r.status if r else 500,
            detail=detail if detail else "Open WebUI: Server Connection Error",
        )


def get_api_key(url, configs):
    parsed_url = urlparse(url)
    base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
    return configs.get(base_url, {}).get("key", None)
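

# A possible direction for the load-balancing TODO at the top of this file: pick the
# backend with the fewest in-flight requests instead of a random one. This is an
# illustrative sketch only -- the in-flight counter and helper below are hypothetical
# and are not wired into any of the handlers in this module.
_inflight_requests: dict[str, int] = {}


def pick_least_connections_url(urls: list[str]) -> str:
    # Choose the URL currently handling the fewest requests; unseen URLs count as 0.
    # Callers would need to increment/decrement _inflight_requests around each request.
    return min(urls, key=lambda url: _inflight_requests.get(url, 0))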


##########################################
#
# API routes
#
##########################################

router = APIRouter()


async def get_status():
    return {"status": True}


class ConnectionVerificationForm(BaseModel):
    url: str
    key: Optional[str] = None


async def verify_connection(
    form_data: ConnectionVerificationForm, user=Depends(get_admin_user)
):
    url = form_data.url
    key = form_data.key

    async with aiohttp.ClientSession(
        timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
    ) as session:
        try:
            async with session.get(
                f"{url}/api/version",
                headers={**({"Authorization": f"Bearer {key}"} if key else {})},
            ) as r:
                if r.status != 200:
                    detail = f"HTTP Error: {r.status}"
                    res = await r.json()
                    if "error" in res:
                        detail = f"External Error: {res['error']}"
                    raise Exception(detail)

                data = await r.json()
                return data
        except aiohttp.ClientError as e:
            log.exception(f"Client error: {str(e)}")
            raise HTTPException(
                status_code=500, detail="Open WebUI: Server Connection Error"
            )
        except Exception as e:
            log.exception(f"Unexpected error: {e}")
            error_detail = f"Unexpected error: {str(e)}"
            raise HTTPException(status_code=500, detail=error_detail)


async def get_config(request: Request, user=Depends(get_admin_user)):
    return {
        "ENABLE_OLLAMA_API": request.app.state.config.ENABLE_OLLAMA_API,
        "OLLAMA_BASE_URLS": request.app.state.config.OLLAMA_BASE_URLS,
        "OLLAMA_API_CONFIGS": request.app.state.config.OLLAMA_API_CONFIGS,
    }


class OllamaConfigForm(BaseModel):
    ENABLE_OLLAMA_API: Optional[bool] = None
    OLLAMA_BASE_URLS: list[str]
    OLLAMA_API_CONFIGS: dict


async def update_config(
    request: Request, form_data: OllamaConfigForm, user=Depends(get_admin_user)
):
    request.app.state.config.ENABLE_OLLAMA_API = form_data.ENABLE_OLLAMA_API
    request.app.state.config.OLLAMA_BASE_URLS = form_data.OLLAMA_BASE_URLS
    request.app.state.config.OLLAMA_API_CONFIGS = form_data.OLLAMA_API_CONFIGS
    # Remove any API configs whose URL is no longer in OLLAMA_BASE_URLS
    base_urls = request.app.state.config.OLLAMA_BASE_URLS
    for url in list(request.app.state.config.OLLAMA_API_CONFIGS.keys()):
        if url not in base_urls:
            request.app.state.config.OLLAMA_API_CONFIGS.pop(url, None)

    return {
        "ENABLE_OLLAMA_API": request.app.state.config.ENABLE_OLLAMA_API,
        "OLLAMA_BASE_URLS": request.app.state.config.OLLAMA_BASE_URLS,
        "OLLAMA_API_CONFIGS": request.app.state.config.OLLAMA_API_CONFIGS,
    }


async def get_all_models(request: Request):
    log.info("get_all_models()")
    if request.app.state.config.ENABLE_OLLAMA_API:
        request_tasks = []
        for idx, url in enumerate(request.app.state.config.OLLAMA_BASE_URLS):
            if url not in request.app.state.config.OLLAMA_API_CONFIGS:
                request_tasks.append(send_get_request(f"{url}/api/tags"))
            else:
                api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})
                enable = api_config.get("enable", True)
                key = api_config.get("key", None)

                if enable:
                    request_tasks.append(send_get_request(f"{url}/api/tags", key))
                else:
                    request_tasks.append(asyncio.ensure_future(asyncio.sleep(0, None)))

        responses = await asyncio.gather(*request_tasks)

        for idx, response in enumerate(responses):
            if response:
                url = request.app.state.config.OLLAMA_BASE_URLS[idx]
                api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})

                prefix_id = api_config.get("prefix_id", None)
                model_ids = api_config.get("model_ids", [])

                if len(model_ids) != 0 and "models" in response:
                    response["models"] = list(
                        filter(
                            lambda model: model["model"] in model_ids,
                            response["models"],
                        )
                    )

                if prefix_id:
                    for model in response.get("models", []):
                        model["model"] = f"{prefix_id}.{model['model']}"

        def merge_models_lists(model_lists):
            merged_models = {}

            for idx, model_list in enumerate(model_lists):
                if model_list is not None:
                    for model in model_list:
                        id = model["model"]
                        if id not in merged_models:
                            model["urls"] = [idx]
                            merged_models[id] = model
                        else:
                            merged_models[id]["urls"].append(idx)

            return list(merged_models.values())

        models = {
            "models": merge_models_lists(
                map(
                    lambda response: response.get("models", []) if response else None,
                    responses,
                )
            )
        }
    else:
        models = {"models": []}

    request.app.state.OLLAMA_MODELS = {
        model["model"]: model for model in models["models"]
    }
    return models
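

# Illustrative example of the merged shape built above: a model available on the
# backends at indices 0 and 2 of OLLAMA_BASE_URLS appears once, with both indices
# (the model name here is a hypothetical example):
# {"models": [{"model": "llama3:latest", ..., "urls": [0, 2]}]}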
async def get_filtered_models(models, user):
    # Filter models based on user access control
    filtered_models = []
    for model in models.get("models", []):
        model_info = Models.get_model_by_id(model["model"])
        if model_info:
            if user.id == model_info.user_id or has_access(
                user.id, type="read", access_control=model_info.access_control
            ):
                filtered_models.append(model)
    return filtered_models


async def get_ollama_tags(
    request: Request, url_idx: Optional[int] = None, user=Depends(get_verified_user)
):
    models = []
    if url_idx is None:
        models = await get_all_models(request)
    else:
        url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
        key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS)

        r = None
        try:
            r = requests.request(
                method="GET",
                url=f"{url}/api/tags",
                headers={**({"Authorization": f"Bearer {key}"} if key else {})},
            )
            r.raise_for_status()
            models = r.json()
        except Exception as e:
            log.exception(e)

            detail = None
            if r is not None:
                try:
                    res = r.json()
                    if "error" in res:
                        detail = f"Ollama: {res['error']}"
                except Exception:
                    detail = f"Ollama: {e}"

            raise HTTPException(
                status_code=r.status_code if r else 500,
                detail=detail if detail else "Open WebUI: Server Connection Error",
            )

    if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
        # get_filtered_models is a coroutine, so it must be awaited here
        models["models"] = await get_filtered_models(models, user)

    return models


async def get_ollama_versions(request: Request, url_idx: Optional[int] = None):
    if request.app.state.config.ENABLE_OLLAMA_API:
        if url_idx is None:
            # returns lowest version
            request_tasks = [
                send_get_request(
                    f"{url}/api/version",
                    request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get(
                        "key", None
                    ),
                )
                for url in request.app.state.config.OLLAMA_BASE_URLS
            ]
            responses = await asyncio.gather(*request_tasks)
            responses = list(filter(lambda x: x is not None, responses))

            if len(responses) > 0:
                lowest_version = min(
                    responses,
                    key=lambda x: tuple(
                        map(int, re.sub(r"^v|-.*", "", x["version"]).split("."))
                    ),
                )
return {"version": lowest_version["version"]} | |
else: | |
raise HTTPException( | |
status_code=500, | |
detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND, | |
) | |
else: | |
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] | |
r = None | |
try: | |
r = requests.request(method="GET", url=f"{url}/api/version") | |
r.raise_for_status() | |
return r.json() | |
except Exception as e: | |
log.exception(e) | |
detail = None | |
if r is not None: | |
try: | |
res = r.json() | |
if "error" in res: | |
detail = f"Ollama: {res['error']}" | |
except Exception: | |
detail = f"Ollama: {e}" | |
raise HTTPException( | |
status_code=r.status_code if r else 500, | |
detail=detail if detail else "Open WebUI: Server Connection Error", | |
) | |
else: | |
return {"version": False} | |


async def get_ollama_loaded_models(request: Request, user=Depends(get_verified_user)):
    """
    List models that are currently loaded into Ollama memory, and which node they are loaded on.
    """
    if request.app.state.config.ENABLE_OLLAMA_API:
        request_tasks = [
            send_get_request(
                f"{url}/api/ps",
                request.app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get(
                    "key", None
                ),
            )
            for url in request.app.state.config.OLLAMA_BASE_URLS
        ]
        responses = await asyncio.gather(*request_tasks)

        return dict(zip(request.app.state.config.OLLAMA_BASE_URLS, responses))
    else:
        return {}
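

# Illustrative response shape for get_ollama_loaded_models: one entry per configured
# base URL, mapping to that node's parsed /api/ps payload, or None if it was
# unreachable (the URLs below are hypothetical examples):
# {"http://localhost:11434": {"models": [...]}, "http://gpu-node:11434": None}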


class ModelNameForm(BaseModel):
    name: str


async def pull_model(
    request: Request,
    form_data: ModelNameForm,
    url_idx: int = 0,
    user=Depends(get_admin_user),
):
    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
    log.info(f"url: {url}")

    # Admin should be able to pull models from any source
    payload = {**form_data.model_dump(exclude_none=True), "insecure": True}

    return await send_post_request(
        url=f"{url}/api/pull",
        payload=json.dumps(payload),
        key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
    )


class PushModelForm(BaseModel):
    name: str
    insecure: Optional[bool] = None
    stream: Optional[bool] = None


async def push_model(
    request: Request,
    form_data: PushModelForm,
    url_idx: Optional[int] = None,
    user=Depends(get_admin_user),
):
    if url_idx is None:
        await get_all_models(request)
        models = request.app.state.OLLAMA_MODELS

        if form_data.name in models:
            url_idx = models[form_data.name]["urls"][0]
        else:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
            )

    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
    log.debug(f"url: {url}")

    return await send_post_request(
        url=f"{url}/api/push",
        payload=form_data.model_dump_json(exclude_none=True).encode(),
        key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
    )


class CreateModelForm(BaseModel):
    name: str
    modelfile: Optional[str] = None
    stream: Optional[bool] = None
    path: Optional[str] = None


async def create_model(
    request: Request,
    form_data: CreateModelForm,
    url_idx: int = 0,
    user=Depends(get_admin_user),
):
    log.debug(f"form_data: {form_data}")
    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]

    return await send_post_request(
        url=f"{url}/api/create",
        payload=form_data.model_dump_json(exclude_none=True).encode(),
        key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
    )


class CopyModelForm(BaseModel):
    source: str
    destination: str


async def copy_model(
    request: Request,
    form_data: CopyModelForm,
    url_idx: Optional[int] = None,
    user=Depends(get_admin_user),
):
    if url_idx is None:
        await get_all_models(request)
        models = request.app.state.OLLAMA_MODELS

        if form_data.source in models:
            url_idx = models[form_data.source]["urls"][0]
        else:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.source),
            )

    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
    key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS)

    # Initialize r so the except block can safely inspect it even when the
    # request call itself raises before assignment
    r = None
    try:
        r = requests.request(
            method="POST",
            url=f"{url}/api/copy",
            headers={
                "Content-Type": "application/json",
                **({"Authorization": f"Bearer {key}"} if key else {}),
            },
            data=form_data.model_dump_json(exclude_none=True).encode(),
        )
        r.raise_for_status()

        log.debug(f"r.text: {r.text}")
        return True
    except Exception as e:
        log.exception(e)

        detail = None
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    detail = f"Ollama: {res['error']}"
            except Exception:
                detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=detail if detail else "Open WebUI: Server Connection Error",
        )


async def delete_model(
    request: Request,
    form_data: ModelNameForm,
    url_idx: Optional[int] = None,
    user=Depends(get_admin_user),
):
    if url_idx is None:
        await get_all_models(request)
        models = request.app.state.OLLAMA_MODELS

        if form_data.name in models:
            url_idx = models[form_data.name]["urls"][0]
        else:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
            )

    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
    key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS)

    r = None
    try:
        r = requests.request(
            method="DELETE",
            url=f"{url}/api/delete",
            data=form_data.model_dump_json(exclude_none=True).encode(),
            headers={
                "Content-Type": "application/json",
                **({"Authorization": f"Bearer {key}"} if key else {}),
            },
        )
        r.raise_for_status()

        log.debug(f"r.text: {r.text}")
        return True
    except Exception as e:
        log.exception(e)

        detail = None
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    detail = f"Ollama: {res['error']}"
            except Exception:
                detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=detail if detail else "Open WebUI: Server Connection Error",
        )


async def show_model_info(
    request: Request, form_data: ModelNameForm, user=Depends(get_verified_user)
):
    await get_all_models(request)
    models = request.app.state.OLLAMA_MODELS

    if form_data.name not in models:
        raise HTTPException(
            status_code=400,
            detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
        )

    url_idx = random.choice(models[form_data.name]["urls"])
    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
    key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS)

    r = None
    try:
        r = requests.request(
            method="POST",
            url=f"{url}/api/show",
            headers={
                "Content-Type": "application/json",
                **({"Authorization": f"Bearer {key}"} if key else {}),
            },
            data=form_data.model_dump_json(exclude_none=True).encode(),
        )
        r.raise_for_status()
        return r.json()
    except Exception as e:
        log.exception(e)

        detail = None
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    detail = f"Ollama: {res['error']}"
            except Exception:
                detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=detail if detail else "Open WebUI: Server Connection Error",
        )


class GenerateEmbedForm(BaseModel):
    model: str
    input: list[str] | str
    truncate: Optional[bool] = None
    options: Optional[dict] = None
    keep_alive: Optional[Union[int, str]] = None


async def embed(
    request: Request,
    form_data: GenerateEmbedForm,
    url_idx: Optional[int] = None,
    user=Depends(get_verified_user),
):
    log.info(f"generate_ollama_batch_embeddings {form_data}")

    if url_idx is None:
        await get_all_models(request)
        models = request.app.state.OLLAMA_MODELS

        model = form_data.model
        if ":" not in model:
            model = f"{model}:latest"

        if model in models:
            url_idx = random.choice(models[model]["urls"])
        else:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
            )

    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
    key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS)

    r = None
    try:
        r = requests.request(
            method="POST",
            url=f"{url}/api/embed",
            headers={
                "Content-Type": "application/json",
                **({"Authorization": f"Bearer {key}"} if key else {}),
            },
            data=form_data.model_dump_json(exclude_none=True).encode(),
        )
        r.raise_for_status()
        data = r.json()
        return data
    except Exception as e:
        log.exception(e)

        detail = None
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    detail = f"Ollama: {res['error']}"
            except Exception:
                detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=detail if detail else "Open WebUI: Server Connection Error",
        )


class GenerateEmbeddingsForm(BaseModel):
    model: str
    prompt: str
    options: Optional[dict] = None
    keep_alive: Optional[Union[int, str]] = None


async def embeddings(
    request: Request,
    form_data: GenerateEmbeddingsForm,
    url_idx: Optional[int] = None,
    user=Depends(get_verified_user),
):
    log.info(f"generate_ollama_embeddings {form_data}")

    if url_idx is None:
        await get_all_models(request)
        models = request.app.state.OLLAMA_MODELS

        model = form_data.model
        if ":" not in model:
            model = f"{model}:latest"

        if model in models:
            url_idx = random.choice(models[model]["urls"])
        else:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
            )

    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
    key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS)

    r = None
    try:
        r = requests.request(
            method="POST",
            url=f"{url}/api/embeddings",
            headers={
                "Content-Type": "application/json",
                **({"Authorization": f"Bearer {key}"} if key else {}),
            },
            data=form_data.model_dump_json(exclude_none=True).encode(),
        )
        r.raise_for_status()
        data = r.json()
        return data
    except Exception as e:
        log.exception(e)

        detail = None
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    detail = f"Ollama: {res['error']}"
            except Exception:
                detail = f"Ollama: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500,
            detail=detail if detail else "Open WebUI: Server Connection Error",
        )


class GenerateCompletionForm(BaseModel):
    model: str
    prompt: str
    suffix: Optional[str] = None
    images: Optional[list[str]] = None
    format: Optional[str] = None
    options: Optional[dict] = None
    system: Optional[str] = None
    template: Optional[str] = None
    context: Optional[list[int]] = None
    stream: Optional[bool] = True
    raw: Optional[bool] = None
    keep_alive: Optional[Union[int, str]] = None


async def generate_completion(
    request: Request,
    form_data: GenerateCompletionForm,
    url_idx: Optional[int] = None,
    user=Depends(get_verified_user),
):
    if url_idx is None:
        await get_all_models(request)
        models = request.app.state.OLLAMA_MODELS

        model = form_data.model
        if ":" not in model:
            model = f"{model}:latest"

        if model in models:
            url_idx = random.choice(models[model]["urls"])
        else:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
            )

    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
    api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})

    prefix_id = api_config.get("prefix_id", None)
    if prefix_id:
        form_data.model = form_data.model.replace(f"{prefix_id}.", "")

    return await send_post_request(
        url=f"{url}/api/generate",
        payload=form_data.model_dump_json(exclude_none=True).encode(),
        key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
    )


class ChatMessage(BaseModel):
    role: str
    content: str
    images: Optional[list[str]] = None


class GenerateChatCompletionForm(BaseModel):
    model: str
    messages: list[ChatMessage]
    format: Optional[dict] = None
    options: Optional[dict] = None
    template: Optional[str] = None
    stream: Optional[bool] = True
    keep_alive: Optional[Union[int, str]] = None


async def get_ollama_url(request: Request, model: str, url_idx: Optional[int] = None):
    if url_idx is None:
        models = request.app.state.OLLAMA_MODELS
        if model not in models:
            raise HTTPException(
                status_code=400,
                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model),
            )
        url_idx = random.choice(models[model].get("urls", []))
    url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
    return url


async def generate_chat_completion(
    request: Request,
    form_data: dict,
    url_idx: Optional[int] = None,
    user=Depends(get_verified_user),
    bypass_filter: Optional[bool] = False,
):
    if BYPASS_MODEL_ACCESS_CONTROL:
        bypass_filter = True

    try:
        form_data = GenerateChatCompletionForm(**form_data)
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=400,
            detail=str(e),
        )

    payload = {**form_data.model_dump(exclude_none=True)}
    if "metadata" in payload:
        del payload["metadata"]

    model_id = payload["model"]
    model_info = Models.get_model_by_id(model_id)

    if model_info:
        if model_info.base_model_id:
            payload["model"] = model_info.base_model_id

        params = model_info.params.model_dump()
        if params:
            if payload.get("options") is None:
                payload["options"] = {}

            payload["options"] = apply_model_params_to_body_ollama(
                params, payload["options"]
            )
            payload = apply_model_system_prompt_to_body(params, payload, user)

        # Check if user has access to the model
        if not bypass_filter and user.role == "user":
            if not (
                user.id == model_info.user_id
                or has_access(
                    user.id, type="read", access_control=model_info.access_control
                )
            ):
                raise HTTPException(
                    status_code=403,
                    detail="Model not found",
                )
    elif not bypass_filter:
        if user.role != "admin":
            raise HTTPException(
                status_code=403,
                detail="Model not found",
            )

    if ":" not in payload["model"]:
        payload["model"] = f"{payload['model']}:latest"

    url = await get_ollama_url(request, payload["model"], url_idx)
    api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})

    prefix_id = api_config.get("prefix_id", None)
    if prefix_id:
        payload["model"] = payload["model"].replace(f"{prefix_id}.", "")

    return await send_post_request(
        url=f"{url}/api/chat",
        payload=json.dumps(payload),
        stream=form_data.stream,
        key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
        content_type="application/x-ndjson",
    )


# TODO: we should update this part once Ollama supports other types
class OpenAIChatMessageContent(BaseModel):
    type: str
    model_config = ConfigDict(extra="allow")


class OpenAIChatMessage(BaseModel):
    role: str
    content: Union[str, list[OpenAIChatMessageContent]]
    model_config = ConfigDict(extra="allow")


class OpenAIChatCompletionForm(BaseModel):
    model: str
    messages: list[OpenAIChatMessage]
    model_config = ConfigDict(extra="allow")


class OpenAICompletionForm(BaseModel):
    model: str
    prompt: str
    model_config = ConfigDict(extra="allow")


async def generate_openai_completion(
    request: Request,
    form_data: dict,
    url_idx: Optional[int] = None,
    user=Depends(get_verified_user),
):
    try:
        form_data = OpenAICompletionForm(**form_data)
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=400,
            detail=str(e),
        )

    payload = {**form_data.model_dump(exclude_none=True, exclude=["metadata"])}
    if "metadata" in payload:
        del payload["metadata"]

    model_id = form_data.model
    if ":" not in model_id:
        model_id = f"{model_id}:latest"

    model_info = Models.get_model_by_id(model_id)
    if model_info:
        if model_info.base_model_id:
            payload["model"] = model_info.base_model_id

        params = model_info.params.model_dump()
        if params:
            payload = apply_model_params_to_body_openai(params, payload)

        # Check if user has access to the model
        if user.role == "user":
            if not (
                user.id == model_info.user_id
                or has_access(
                    user.id, type="read", access_control=model_info.access_control
                )
            ):
                raise HTTPException(
                    status_code=403,
                    detail="Model not found",
                )
    else:
        if user.role != "admin":
            raise HTTPException(
                status_code=403,
                detail="Model not found",
            )

    if ":" not in payload["model"]:
        payload["model"] = f"{payload['model']}:latest"

    url = await get_ollama_url(request, payload["model"], url_idx)
    api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})

    prefix_id = api_config.get("prefix_id", None)
    if prefix_id:
        payload["model"] = payload["model"].replace(f"{prefix_id}.", "")

    return await send_post_request(
        url=f"{url}/v1/completions",
        payload=json.dumps(payload),
        stream=payload.get("stream", False),
        key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
    )


async def generate_openai_chat_completion(
    request: Request,
    form_data: dict,
    url_idx: Optional[int] = None,
    user=Depends(get_verified_user),
):
    try:
        completion_form = OpenAIChatCompletionForm(**form_data)
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=400,
            detail=str(e),
        )

    payload = {**completion_form.model_dump(exclude_none=True, exclude=["metadata"])}
    if "metadata" in payload:
        del payload["metadata"]

    model_id = completion_form.model
    if ":" not in model_id:
        model_id = f"{model_id}:latest"

    model_info = Models.get_model_by_id(model_id)
    if model_info:
        if model_info.base_model_id:
            payload["model"] = model_info.base_model_id

        params = model_info.params.model_dump()
        if params:
            payload = apply_model_params_to_body_openai(params, payload)
            payload = apply_model_system_prompt_to_body(params, payload, user)

        # Check if user has access to the model
        if user.role == "user":
            if not (
                user.id == model_info.user_id
                or has_access(
                    user.id, type="read", access_control=model_info.access_control
                )
            ):
                raise HTTPException(
                    status_code=403,
                    detail="Model not found",
                )
    else:
        if user.role != "admin":
            raise HTTPException(
                status_code=403,
                detail="Model not found",
            )

    if ":" not in payload["model"]:
        payload["model"] = f"{payload['model']}:latest"

    url = await get_ollama_url(request, payload["model"], url_idx)
    api_config = request.app.state.config.OLLAMA_API_CONFIGS.get(url, {})

    prefix_id = api_config.get("prefix_id", None)
    if prefix_id:
        payload["model"] = payload["model"].replace(f"{prefix_id}.", "")

    return await send_post_request(
        url=f"{url}/v1/chat/completions",
        payload=json.dumps(payload),
        stream=payload.get("stream", False),
        key=get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS),
    )


async def get_openai_models(
    request: Request,
    url_idx: Optional[int] = None,
    user=Depends(get_verified_user),
):
    models = []

    if url_idx is None:
        model_list = await get_all_models(request)
        models = [
            {
                "id": model["model"],
                "object": "model",
                "created": int(time.time()),
                "owned_by": "openai",
            }
            for model in model_list["models"]
        ]
    else:
        url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]

        r = None
        try:
            r = requests.request(method="GET", url=f"{url}/api/tags")
            r.raise_for_status()
            model_list = r.json()

            models = [
                {
                    "id": model["model"],
                    "object": "model",
                    "created": int(time.time()),
                    "owned_by": "openai",
                }
                for model in model_list["models"]
            ]
        except Exception as e:
            log.exception(e)
            error_detail = "Open WebUI: Server Connection Error"
            if r is not None:
                try:
                    res = r.json()
                    if "error" in res:
                        error_detail = f"Ollama: {res['error']}"
                except Exception:
                    error_detail = f"Ollama: {e}"

            raise HTTPException(
                status_code=r.status_code if r else 500,
                detail=error_detail,
            )

    if user.role == "user" and not BYPASS_MODEL_ACCESS_CONTROL:
        # Filter models based on user access control
        filtered_models = []
        for model in models:
            model_info = Models.get_model_by_id(model["id"])
            if model_info:
                if user.id == model_info.user_id or has_access(
                    user.id, type="read", access_control=model_info.access_control
                ):
                    filtered_models.append(model)
        models = filtered_models

    return {
        "data": models,
        "object": "list",
    }


class UrlForm(BaseModel):
    url: str


class UploadBlobForm(BaseModel):
    filename: str


def parse_huggingface_url(hf_url):
    try:
        # Parse the URL
        parsed_url = urlparse(hf_url)

        # Get the path and split it into components
        path_components = parsed_url.path.split("/")

        # Extract the desired output
        model_file = path_components[-1]

        return model_file
    except ValueError:
        return None
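

# Illustrative usage, with the sample URL that also appears in the comment before
# download_model below:
# parse_huggingface_url(
#     "https://huggingface.co./TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
# ) returns "stablelm-zephyr-3b.Q2_K.gguf", i.e. the last path component.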


async def download_file_stream(
    ollama_url, file_url, file_path, file_name, chunk_size=1024 * 1024
):
    done = False

    if os.path.exists(file_path):
        current_size = os.path.getsize(file_path)
    else:
        current_size = 0

    # Resume a partial download from where the file on disk left off
    headers = {"Range": f"bytes={current_size}-"} if current_size > 0 else {}

    timeout = aiohttp.ClientTimeout(total=600)  # Set the timeout

    async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
        async with session.get(file_url, headers=headers) as response:
            total_size = int(response.headers.get("content-length", 0)) + current_size

            with open(file_path, "ab+") as file:
                async for data in response.content.iter_chunked(chunk_size):
                    current_size += len(data)
                    file.write(data)

                    done = current_size == total_size
                    progress = round((current_size / total_size) * 100, 2)

                    yield f'data: {{"progress": {progress}, "completed": {current_size}, "total": {total_size}}}\n\n'

                if done:
                    file.seek(0)
                    hashed = calculate_sha256(file)
                    file.seek(0)

                    url = f"{ollama_url}/api/blobs/sha256:{hashed}"
                    response = requests.post(url, data=file)

                    if response.ok:
                        res = {
                            "done": done,
                            "blob": f"sha256:{hashed}",
                            "name": file_name,
                        }
                        os.remove(file_path)

                        yield f"data: {json.dumps(res)}\n\n"
                    else:
                        raise Exception(
                            "Ollama: Could not create blob, Please try again."
                        )
# url = "https://huggingface.co./TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf" | |
async def download_model( | |
request: Request, | |
form_data: UrlForm, | |
url_idx: Optional[int] = None, | |
user=Depends(get_admin_user), | |
): | |
allowed_hosts = ["https://huggingface.co./", "https://github.com/"] | |
if not any(form_data.url.startswith(host) for host in allowed_hosts): | |
raise HTTPException( | |
status_code=400, | |
detail="Invalid file_url. Only URLs from allowed hosts are permitted.", | |
) | |
if url_idx is None: | |
url_idx = 0 | |
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx] | |
file_name = parse_huggingface_url(form_data.url) | |
if file_name: | |
file_path = f"{UPLOAD_DIR}/{file_name}" | |
return StreamingResponse( | |
download_file_stream(url, form_data.url, file_path, file_name), | |
) | |
else: | |
return None | |


def upload_model(
    request: Request,
    file: UploadFile = File(...),
    url_idx: Optional[int] = None,
    user=Depends(get_admin_user),
):
    if url_idx is None:
        url_idx = 0
    ollama_url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]

    file_path = f"{UPLOAD_DIR}/{file.filename}"

    # Save file in chunks
    with open(file_path, "wb+") as f:
        for chunk in file.file:
            f.write(chunk)

    def file_process_stream():
        nonlocal ollama_url
        total_size = os.path.getsize(file_path)
        chunk_size = 1024 * 1024
        try:
            with open(file_path, "rb") as f:
                total = 0
                done = False

                while not done:
                    chunk = f.read(chunk_size)
                    if not chunk:
                        done = True
                        continue

                    total += len(chunk)
                    progress = round((total / total_size) * 100, 2)

                    res = {
                        "progress": progress,
                        "total": total_size,
                        "completed": total,
                    }
                    yield f"data: {json.dumps(res)}\n\n"

                if done:
                    f.seek(0)
                    hashed = calculate_sha256(f)
                    f.seek(0)

                    url = f"{ollama_url}/api/blobs/sha256:{hashed}"
                    response = requests.post(url, data=f)

                    if response.ok:
                        res = {
                            "done": done,
                            "blob": f"sha256:{hashed}",
                            "name": file.filename,
                        }
                        os.remove(file_path)
                        yield f"data: {json.dumps(res)}\n\n"
                    else:
                        raise Exception(
                            "Ollama: Could not create blob, Please try again."
                        )
        except Exception as e:
            res = {"error": str(e)}
            yield f"data: {json.dumps(res)}\n\n"

    return StreamingResponse(file_process_stream(), media_type="text/event-stream")