import gradio as gr from sentence_transformers import SentenceTransformer import os import time import threading import queue import psycopg2 import zlib import numpy as np from urllib.parse import urlparse import logging from sklearn.preprocessing import normalize from concurrent.futures import ThreadPoolExecutor import requests from fastapi import FastAPI, HTTPException, Query from typing import List, Optional import uvicorn from starlette.requests import Request from starlette.responses import HTMLResponse, JSONResponse from fastapi.responses import HTMLResponse from fastapi.middleware.cors import CORSMiddleware from fastapi.staticfiles import StaticFiles # Настройка логирования logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') # Настройки базы данных PostgreSQL DATABASE_URL = os.environ.get("DATABASE_URL") if DATABASE_URL is None: raise ValueError("DATABASE_URL environment variable not set.") parsed_url = urlparse(DATABASE_URL) db_params = { "host": parsed_url.hostname, "port": parsed_url.port, "database": parsed_url.path.lstrip("/"), "user": parsed_url.username, "password": parsed_url.password, "sslmode": "require" } # Загружаем модель эмбеддингов model_name = "BAAI/bge-m3" logging.info(f"Загрузка модели {model_name}...") model = SentenceTransformer(model_name) logging.info("Модель загружена успешно.") # Jina AI Reranker API JINA_API_URL = 'https://api.jina.ai/v1/rerank' JINA_API_KEY = os.environ.get("JINA_API_KEY") # Используем переменную окружения if JINA_API_KEY is None: raise ValueError("JINA_API_KEY environment variable not set.") JINA_RERANKER_MODEL = "jina-reranker-v2-base-multilingual" # Jina AI Dashboard API JINA_DASHBOARD_API_URL = 'https://embeddings-dashboard-api.jina.ai/api/v1/api_key/user' # Имена таблиц embeddings_table = "movie_embeddings" query_cache_table = "query_cache" movies_table = "Movies" # Имя таблицы с фильмами # FastAPI приложение app = FastAPI() # Разрешаем CORS, чтобы Gradio мог обращаться к API app.add_middleware( CORSMiddleware, allow_origins=["*"], # Разрешаем все источники, в продакшене лучше указать конкретные allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) def get_db_connection(): """Устанавливает соединение с базой данных.""" try: conn = psycopg2.connect(**db_params) return conn except Exception as e: logging.error(f"Ошибка подключения к базе данных: {e}") return None def setup_database(): """Настраивает базу данных: создает расширение, таблицы и индексы.""" conn = get_db_connection() if conn is None: return try: with conn.cursor() as cur: # Создаем расширение pgvector если его нет cur.execute("CREATE EXTENSION IF NOT EXISTS vector;") # Создаем таблицу для хранения эмбеддингов фильмов cur.execute(f""" CREATE TABLE IF NOT EXISTS "{embeddings_table}" ( movie_id INTEGER PRIMARY KEY, embedding_crc32 BIGINT, string_crc32 BIGINT, model_name TEXT, embedding vector(1024) ); CREATE INDEX IF NOT EXISTS idx_string_crc32 ON "{embeddings_table}" (string_crc32); """) # Создаем таблицу для кэширования запросов cur.execute(f""" CREATE TABLE IF NOT EXISTS "{query_cache_table}" ( query_crc32 BIGINT PRIMARY KEY, query TEXT, model_name TEXT, embedding vector(1024), created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP ); CREATE INDEX IF NOT EXISTS idx_query_crc32 ON "{query_cache_table}" (query_crc32); CREATE INDEX IF NOT EXISTS idx_created_at ON "{query_cache_table}" (created_at); """) conn.commit() logging.info("База данных успешно настроена.") except Exception as e: logging.error(f"Ошибка при настройке базы данных: {e}") conn.rollback() finally: conn.close() # Настраиваем базу данных при запуске setup_database() def calculate_crc32(text): """Вычисляет CRC32 для строки.""" return zlib.crc32(text.encode('utf-8')) & 0xFFFFFFFF def encode_string(text): """Кодирует строку в эмбеддинг.""" embedding = model.encode(text, convert_to_tensor=True, normalize_embeddings=True) return embedding.cpu().numpy() def get_embedding_from_db(conn, table_name, crc32_column, crc32_value, model_name): """Получает эмбеддинг из базы данных.""" try: with conn.cursor() as cur: cur.execute(f"SELECT embedding FROM \"{table_name}\" WHERE \"{crc32_column}\" = %s AND model_name = %s", (crc32_value, model_name)) result = cur.fetchone() if result and result[0]: # Нормализуем эмбеддинг после извлечения из БД return normalize(np.array(result[0]).reshape(1, -1))[0] except Exception as e: logging.error(f"Ошибка при получении эмбеддинга из БД: {e}") return None def get_movie_data_from_db(conn, movie_ids): """ Получает данные фильмов из таблицы Movies по списку ID, включая предположительно URL-адрес постера и рейтинг. """ movie_data_dict = {} try: with conn.cursor() as cur: cur.execute(f""" SELECT id, data, jsonb_build_object( 'Название', data->>'name', 'Год', data->>'year', 'Жанры', (SELECT string_agg(genre->>'name', ', ') FROM jsonb_array_elements(data->'genres') AS genre), 'Описание', COALESCE(data->>'description', ''), 'Постер', data->'poster'->'previewUrl', 'Рейтинг', data->'rating'->'kp' ) AS prepared_json FROM "{movies_table}" WHERE id IN %s """, (tuple(movie_ids),)) for movie_id, movie_data, prepared_json in cur.fetchall(): # Исправлено: убрано формирование prepared_string, так как оно больше не используется для вывода relevance_score movie_data_dict[movie_id] = (movie_data, prepared_json) except Exception as e: logging.error(f"Ошибка при получении данных фильмов из БД: {e}") return movie_data_dict def get_jina_ai_balance(api_key: str): """Получает остаток баланса Jina AI.""" try: headers = { 'Content-Type': 'application/json' } params = { 'api_key': api_key } response = requests.get(JINA_DASHBOARD_API_URL, headers=headers, params=params) response.raise_for_status() data = response.json() return data['wallet']['total_balance'] except requests.exceptions.RequestException as e: logging.error(f"Ошибка при запросе к API баланса Jina AI: {e}") return None def rerank_with_api(query, results, top_k, rerank_top_k=None, api_key=None): """Переранжирует результаты с помощью Jina AI Reranker API.""" logging.info(f"Начало переранжирования для запроса: '{query}'") # Если rerank_top_k равен 0, не используем реранкер if rerank_top_k == 0: logging.info("Переранжирование отключено (rerank_top_k = 0).") return results, False, 0 # Получаем данные фильмов из БД conn = get_db_connection() movie_ids = [movie_id for movie_id, _ in results] movie_data_dict = get_movie_data_from_db(conn, movie_ids) conn.close() documents = [] for movie_id, _ in results: movie_data, prepared_json = movie_data_dict.get(movie_id, (None, None)) if movie_data: # Исправлено: добавлено формирование строки, так как она используется в data prepared_string = ( f"Название: {prepared_json['Название']}\n" f"Год: {prepared_json['Год']}\n" f"Жанры: {prepared_json['Жанры']}\n" f"Описание: {prepared_json['Описание']}" ) documents.append(prepared_string) else: logging.warning(f"Данные для фильма с ID {movie_id} не найдены в БД.") reranked_count = min(rerank_top_k or top_k*2, len(documents)) headers = { 'Content-Type': 'application/json', 'Authorization': f'Bearer {api_key or JINA_API_KEY}' } data = { "model": JINA_RERANKER_MODEL, "query": query, "top_n": rerank_top_k or top_k*2, "documents": documents } logging.info(f"Отправка данных на реранжировку (documents count): {len(data['documents'])}, top_n: {data['top_n']}") try: response = requests.post(JINA_API_URL, headers=headers, json=data) response.raise_for_status() result = response.json() logging.info(f"Ответ от API реранжировщика получен.") reranked_results = [] if 'results' in result: for item in result['results']: index = item['index'] movie_id = results[index][0] reranked_results.append((movie_id, item['relevance_score'])) else: logging.warning("Ответ от API не содержит ключа 'results'.") logging.info("Переранжирование завершено.") return reranked_results, True, reranked_count except requests.exceptions.RequestException as e: logging.error(f"Ошибка при запросе к API реранжировщика: {e}") return results, False, reranked_count def search_movies_internal(query: str, top_k: int = 25, rerank_top_k: Optional[int] = None, jina_api_key: Optional[str] = None): """Внутренняя функция для поиска фильмов по запросу (используется и в Gradio, и в API).""" start_time = time.time() try: conn = get_db_connection() if conn is None: raise Exception("Ошибка подключения к базе данных") query_crc32 = calculate_crc32(query) query_embedding = get_embedding_from_db(conn, query_cache_table, "query_crc32", query_crc32, model_name) if query_embedding is None: query_embedding = encode_string(query) try: with conn.cursor() as cur: cur.execute(f""" INSERT INTO "{query_cache_table}" (query_crc32, query, model_name, embedding) VALUES (%s, %s, %s, %s) ON CONFLICT (query_crc32) DO NOTHING """, (query_crc32, query, model_name, query_embedding.tolist())) conn.commit() logging.info(f"Сохранен новый эмбеддинг запроса: {query}") except Exception as e: logging.error(f"Ошибка при сохранении эмбеддинга запроса: {e}") conn.rollback() # Определяем количество фильмов для запроса из БД db_limit = rerank_top_k or top_k * 2 # Модифицируем запрос для поддержки поиска по числовому идентификатору try: with conn.cursor() as cur: if query.isdigit(): # Если запрос является числом, ищем по ID cur.execute(f""" SELECT m.movie_id, 1.0 as similarity FROM "{embeddings_table}" m WHERE m.movie_id = %s LIMIT 1 """, (int(query),)) results = cur.fetchall() logging.info(f"Найдено {len(results)} результатов по ID.") else: cur.execute(f""" WITH query_embedding AS ( SELECT embedding FROM "{query_cache_table}" WHERE query_crc32 = %s ) SELECT m.movie_id, 1 - (m.embedding <=> (SELECT embedding FROM query_embedding)) as similarity FROM "{embeddings_table}" m, query_embedding ORDER BY similarity DESC LIMIT %s """, (query_crc32, int(db_limit))) results = cur.fetchall() logging.info(f"Найдено {len(results)} предварительных результатов поиска по тексту.") except Exception as e: logging.error(f"Ошибка при выполнении поискового запроса: {e}") results = [] finally: conn.close() # Используем реранкер только если rerank_top_k не равен 0 if rerank_top_k != 0: reranked_results, rerank_success, reranked_count = rerank_with_api(query, results, top_k, rerank_top_k, jina_api_key) else: reranked_results = results rerank_success = False reranked_count = 0 if not rerank_success: logging.warning("Переранжировка не удалась, используются сырые результаты.") reranked_results = results[:top_k] # Используем срез для ограничения количества результатов else: reranked_results = reranked_results[:top_k] conn = get_db_connection() movie_ids = [movie_id for movie_id, _ in reranked_results] movie_data_dict = get_movie_data_from_db(conn, movie_ids) # Получаем общее количество фильмов в базе данных try: with conn.cursor() as cur: cur.execute(f'SELECT COUNT(*) FROM "{movies_table}"') total_movies = cur.fetchone()[0] except Exception as e: logging.error(f"Ошибка при получении общего количества фильмов: {e}") total_movies = 0 # Получаем количество фильмов, по которым производился поиск (количество строк в movie_embeddings) try: with conn.cursor() as cur: cur.execute(f'SELECT COUNT(*) FROM "{embeddings_table}"') searched_movies = cur.fetchone()[0] except Exception as e: logging.error(f"Ошибка при получении количества фильмов для поиска: {e}") searched_movies = 0 finally: conn.close() formatted_results = [] for movie_id, score in reranked_results: movie_data, prepared_json = movie_data_dict.get(movie_id, (None, None)) if movie_data: formatted_results.append({ "movie_id": movie_id, "name": prepared_json['Название'], "year": prepared_json['Год'], "genres": prepared_json['Жанры'], "description": prepared_json['Описание'], "poster_preview_url": prepared_json['Постер'], "rating_kp": prepared_json['Рейтинг'], "relevance_score": score # Убрано условие `if rerank_success else 0.0` и всегда возвращаем score }) else: logging.warning(f"Данные для фильма с ID {movie_id} не найдены в БД.") search_time = time.time() - start_time logging.info(f"Поиск выполнен за {search_time:.2f} секунд.") jina_balance = get_jina_ai_balance(jina_api_key or JINA_API_KEY) return { "status": "success", "results": formatted_results, "search_time": search_time, "total_movies": total_movies, "searched_movies": searched_movies, "returned_movies": len(formatted_results), # Количество возвращенных фильмов "reranked_movies": reranked_count, # Количество фильмов, обработанных реранкером "jina_balance": jina_balance # Остаток баланса Jina AI }, search_time except Exception as e: logging.error(f"Ошибка при выполнении поиска: {e}") return { "status": "error", "message": str(e), "search_time": 0, "total_movies": 0, "searched_movies": 0, "returned_movies": 0, "reranked_movies": 0, "jina_balance": None }, 0 @app.get("/search/", response_model=dict) async def api_search_movies(query: str = Query(..., description="Поисковый запрос"), top_k: int = Query(25, description="Количество возвращаемых результатов"), rerank_top_k: Optional[int] = Query(None, description="Количество фильмов для передачи в реранкер (если не указано, то top_k*2)"), jina_api_key: Optional[str] = Query(None, description="API ключ Jina AI (если не указан, используется значение из переменной окружения JINA_API_KEY)")): """ API endpoint для поиска фильмов. Parameters ---------- query : str Поисковый запрос. top_k : int, optional Количество возвращаемых результатов, по умолчанию 25. rerank_top_k : Optional[int], optional Количество фильмов для передачи в реранкер. Если 0 - реранкер не используется. Если не указано, то используется top_k*2. По умолчанию None. jina_api_key : Optional[str], optional API ключ Jina AI. Если не указан, используется значение из переменной окружения JINA_API_KEY. По умолчанию None. Returns ------- dict Словарь с результатами поиска. """ try: results, _ = search_movies_internal(query, top_k, rerank_top_k, jina_api_key) return results except Exception as e: raise HTTPException(status_code=500, detail=str(e)) # Рут-эндпоинт, который отдаёт HTML-страницу @app.get("/", response_class=HTMLResponse) async def root(): return """ VooFlex
""" # Рут-эндпоинт для демонстрации, что FastAPI работает @app.get("/api") async def root(): return {"message": "FastAPI is running. Access the API documentation at /docs"} # Запускаем FastAPI if __name__ == "__main__": uvicorn.run(app, host="0.0.0.0", port=7860)