opex792's picture
Update app.py
bbf6f5b verified
raw
history blame
20.1 kB
import gradio as gr
from sentence_transformers import SentenceTransformer
import os
import time
import threading
import queue
import psycopg2
import zlib
import numpy as np
from urllib.parse import urlparse
import logging
from sklearn.preprocessing import normalize
from concurrent.futures import ThreadPoolExecutor
import requests
# Настройка логирования
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Настройки базы данных PostgreSQL
DATABASE_URL = os.environ.get("DATABASE_URL")
if DATABASE_URL is None:
raise ValueError("DATABASE_URL environment variable not set.")
parsed_url = urlparse(DATABASE_URL)
db_params = {
"host": parsed_url.hostname,
"port": parsed_url.port,
"database": parsed_url.path.lstrip("/"),
"user": parsed_url.username,
"password": parsed_url.password,
"sslmode": "require"
}
# Загружаем модель эмбеддингов
model_name = "BAAI/bge-m3"
logging.info(f"Загрузка модели {model_name}...")
model = SentenceTransformer(model_name)
logging.info("Модель загружена успешно.")
# Jina AI Reranker API
JINA_API_URL = 'https://api.jina.ai/v1/rerank'
JINA_API_KEY = os.environ.get("JINA_API_KEY")
if JINA_API_KEY is None:
raise ValueError("JINA_API_KEY environment variable not set.")
JINA_RERANKER_MODEL = "jina-reranker-v2-base-multilingual"
# Имена таблиц
embeddings_table = "movie_embeddings"
query_cache_table = "query_cache"
movies_table = "Movies" # Имя таблицы с фильмами
# Максимальный размер таблицы кэша запросов в байтах (50MB)
MAX_CACHE_SIZE = 50 * 1024 * 1024
# Очередь для необработанных фильмов
movies_queue = queue.Queue()
# Флаг, указывающий, что обработка фильмов завершена
processing_complete = False
# Флаг, указывающий, что выполняется поиск
search_in_progress = False
# Блокировка для доступа к базе данных
db_lock = threading.Lock()
# Размер пакета для обработки эмбеддингов
batch_size = 32
# Количество потоков для параллельной обработки
num_threads = 5
def get_db_connection():
"""Устанавливает соединение с базой данных."""
try:
conn = psycopg2.connect(**db_params)
return conn
except Exception as e:
logging.error(f"Ошибка подключения к базе данных: {e}")
return None
def setup_database():
"""Настраивает базу данных: создает расширение, таблицы и индексы."""
conn = get_db_connection()
if conn is None:
return
try:
with conn.cursor() as cur:
# Создаем расширение pgvector если его нет
cur.execute("CREATE EXTENSION IF NOT EXISTS vector;")
# Создаем таблицу для хранения эмбеддингов фильмов
cur.execute(f"""
CREATE TABLE IF NOT EXISTS "{embeddings_table}" (
movie_id INTEGER PRIMARY KEY,
embedding_crc32 BIGINT,
string_crc32 BIGINT,
model_name TEXT,
embedding vector(1024)
);
CREATE INDEX IF NOT EXISTS idx_string_crc32 ON "{embeddings_table}" (string_crc32);
""")
# Создаем таблицу для кэширования запросов
cur.execute(f"""
CREATE TABLE IF NOT EXISTS "{query_cache_table}" (
query_crc32 BIGINT PRIMARY KEY,
query TEXT,
model_name TEXT,
embedding vector(1024),
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX IF NOT EXISTS idx_query_crc32 ON "{query_cache_table}" (query_crc32);
CREATE INDEX IF NOT EXISTS idx_created_at ON "{query_cache_table}" (created_at);
""")
conn.commit()
logging.info("База данных успешно настроена.")
except Exception as e:
logging.error(f"Ошибка при настройке базы данных: {e}")
conn.rollback()
finally:
conn.close()
# Настраиваем базу данных при запуске
setup_database()
def calculate_crc32(text):
"""Вычисляет CRC32 для строки."""
return zlib.crc32(text.encode('utf-8')) & 0xFFFFFFFF
def encode_string(text):
"""Кодирует строку в эмбеддинг."""
embedding = model.encode(text, convert_to_tensor=True, normalize_embeddings=True)
return embedding.cpu().numpy()
def get_movies_without_embeddings():
"""Получает список фильмов, для которых нужно создать эмбеддинги."""
conn = get_db_connection()
if conn is None:
return []
movies_to_process = []
try:
with conn.cursor() as cur:
# Получаем список ID фильмов, которые уже есть в таблице эмбеддингов
cur.execute(f"SELECT movie_id FROM \"{embeddings_table}\"")
existing_ids = {row[0] for row in cur.fetchall()}
# Получаем список всех фильмов из таблицы Movies с подготовленной строкой
cur.execute(f"""
SELECT id, data,
jsonb_build_object(
'Название', data->>'name',
'Год', data->>'year',
'Жанры', (SELECT string_agg(genre->>'name', ', ') FROM jsonb_array_elements(data->'genres') AS genre),
'Описание', COALESCE(data->>'description', '')
) AS prepared_json
FROM "{movies_table}"
""")
all_movies = cur.fetchall()
# Фильтруем только те фильмы, которых нет в таблице эмбеддингов
for movie_id, movie_data, prepared_json in all_movies:
if movie_id not in existing_ids:
prepared_string = f"Название: {prepared_json['Название']}\nГод: {prepared_json['Год']}\nЖанры: {prepared_json['Жанры']}\nОписание: {prepared_json['Описание']}"
movies_to_process.append((movie_id, movie_data, prepared_string))
logging.info(f"Найдено {len(movies_to_process)} фильмов для обработки.")
except Exception as e:
logging.error(f"Ошибка при получении списка фильмов для обработки: {e}")
finally:
conn.close()
return movies_to_process
def get_embedding_from_db(conn, table_name, crc32_column, crc32_value, model_name):
"""Получает эмбеддинг из базы данных."""
try:
with conn.cursor() as cur:
cur.execute(f"SELECT embedding FROM \"{table_name}\" WHERE \"{crc32_column}\" = %s AND model_name = %s",
(crc32_value, model_name))
result = cur.fetchone()
if result and result[0]:
# Нормализуем эмбеддинг после извлечения из БД
return normalize(np.array(result[0]).reshape(1, -1))[0]
except Exception as e:
logging.error(f"Ошибка при получении эмбеддинга из БД: {e}")
return None
def insert_embedding(conn, table_name, movie_id, embedding_crc32, string_crc32, embedding):
"""Вставляет эмбеддинг в базу данных."""
try:
# Нормализуем эмбеддинг перед сохранением
normalized_embedding = normalize(embedding.reshape(1, -1))[0]
with conn.cursor() as cur:
cur.execute(f"""
INSERT INTO "{table_name}"
(movie_id, embedding_crc32, string_crc32, model_name, embedding)
VALUES (%s, %s, %s, %s, %s)
ON CONFLICT (movie_id) DO NOTHING
""", (movie_id, embedding_crc32, string_crc32, model_name, normalized_embedding.tolist()))
conn.commit()
return True
except Exception as e:
logging.error(f"Ошибка при вставке эмбеддинга: {e}")
conn.rollback()
return False
def process_batch(batch):
"""Обрабатывает пакет фильмов, создавая для них эмбеддинги."""
conn = get_db_connection()
if conn is None:
return
try:
for movie_id, movie_data, prepared_string in batch:
string_crc32 = calculate_crc32(prepared_string)
# Проверяем существующий эмбеддинг
existing_embedding = get_embedding_from_db(conn, embeddings_table, "string_crc32", string_crc32, model_name)
if existing_embedding is None:
embedding = encode_string(prepared_string)
embedding_crc32 = calculate_crc32(str(embedding.tolist()))
if insert_embedding(conn, embeddings_table, movie_id, embedding_crc32, string_crc32, embedding):
logging.info(f"Сохранен эмбеддинг для '{movie_data['name']}' (ID: {movie_id})")
else:
logging.error(f"Ошибка сохранения эмбеддинга для '{movie_data['name']}' (ID: {movie_id})")
else:
logging.info(f"Эмбеддинг для '{movie_data['name']}' (ID: {movie_id}) уже существует")
except Exception as e:
logging.error(f"Ошибка при обработке пакета фильмов: {e}")
finally:
conn.close()
def process_movies():
"""Обрабатывает фильмы, создавая для них эмбеддинги."""
global processing_complete
logging.info("Начало обработки фильмов.")
# Получаем список фильмов, которые нужно обработать
movies_to_process = get_movies_without_embeddings()
if not movies_to_process:
logging.info("Все фильмы уже обработаны.")
processing_complete = True
return
# Добавляем фильмы в очередь
for movie in movies_to_process:
movies_queue.put(movie)
with ThreadPoolExecutor(max_workers=num_threads) as executor:
try:
while not movies_queue.empty():
if search_in_progress:
time.sleep(1)
continue
batch = []
while not movies_queue.empty() and len(batch) < batch_size:
try:
movie = movies_queue.get_nowait()
batch.append(movie)
except queue.Empty:
break
if not batch:
break
executor.submit(process_batch, batch)
logging.info(f"Отправлен на обработку пакет из {len(batch)} фильмов.")
except Exception as e:
logging.error(f"Ошибка при обработке фильмов: {e}")
processing_complete = True
logging.info("Обработка фильмов завершена")
def get_movie_data_from_db(conn, movie_ids):
"""Получает данные фильмов из таблицы Movies по списку ID."""
movie_data_dict = {}
try:
with conn.cursor() as cur:
cur.execute(f"""
SELECT id, data,
jsonb_build_object(
'Название', data->>'name',
'Год', data->>'year',
'Жанры', (SELECT string_agg(genre->>'name', ', ') FROM jsonb_array_elements(data->'genres') AS genre),
'Описание', COALESCE(data->>'description', '')
) AS prepared_json
FROM "{movies_table}"
WHERE id IN %s
""", (tuple(movie_ids),))
for movie_id, movie_data, prepared_json in cur.fetchall():
prepared_string = f"Название: {prepared_json['Название']}\nГод: {prepared_json['Год']}\nЖанры: {prepared_json['Жанры']}\nОписание: {prepared_json['Описание']}"
movie_data_dict[movie_id] = (movie_data, prepared_string)
except Exception as e:
logging.error(f"Ошибка при получении данных фильмов из БД: {e}")
return movie_data_dict
def rerank_with_api(query, results, top_k):
"""Переранжирует результаты с помощью Jina AI Reranker API."""
logging.info(f"Начало переранжирования для запроса: '{query}'")
# Получаем данные фильмов из БД
conn = get_db_connection()
movie_ids = [movie_id for movie_id, _ in results]
movie_data_dict = get_movie_data_from_db(conn, movie_ids)
conn.close()
documents = []
for movie_id, _ in results:
movie_data, prepared_string = movie_data_dict.get(movie_id, (None, None))
if movie_data:
documents.append(prepared_string)
else:
logging.warning(f"Данные для фильма с ID {movie_id} не найдены в БД.")
headers = {
'Content-Type': 'application/json',
'Authorization': f'Bearer {JINA_API_KEY}'
}
data = {
"model": JINA_RERANKER_MODEL,
"query": query,
"top_n": top_k,
"documents": documents
}
logging.info(f"Отправка данных на реранжировку (documents count): {len(data['documents'])}")
try:
response = requests.post(JINA_API_URL, headers=headers, json=data)
response.raise_for_status()
result = response.json()
logging.info(f"Ответ от API реранжировщика получен.")
reranked_results = []
if 'results' in result:
for item in result['results']:
index = item['index']
movie_id = results[index][0]
reranked_results.append((movie_id, item['relevance_score']))
else:
logging.warning("Ответ от API не содержит ключа 'results'.")
logging.info("Переранжирование завершено.")
return reranked_results
except requests.exceptions.RequestException as e:
logging.error(f"Ошибка при запросе к API реранжировщика: {e}")
return []
def search_movies(query, top_k=25):
"""Выполняет поиск фильмов по запросу."""
global search_in_progress
search_in_progress = True
start_time = time.time()
try:
conn = get_db_connection()
if conn is None:
return "<p>Ошибка подключения к базе данных</p>"
query_crc32 = calculate_crc32(query)
query_embedding = get_embedding_from_db(conn, query_cache_table, "query_crc32", query_crc32, model_name)
if query_embedding is None:
query_embedding = encode_string(query)
try:
with conn.cursor() as cur:
cur.execute(f"""
INSERT INTO "{query_cache_table}" (query_crc32, query, model_name, embedding)
VALUES (%s, %s, %s, %s)
ON CONFLICT (query_crc32) DO NOTHING
""", (query_crc32, query, model_name, query_embedding.tolist()))
conn.commit()
logging.info(f"Сохранен новый эмбеддинг запроса: {query}")
except Exception as e:
logging.error(f"Ошибка при сохранении эмбеддинга запроса: {e}")
conn.rollback()
# Используем косинусное расстояние для поиска
try:
with conn.cursor() as cur:
cur.execute(f"""
WITH query_embedding AS (
SELECT embedding
FROM "{query_cache_table}"
WHERE query_crc32 = %s
)
SELECT m.movie_id, 1 - (m.embedding <=> (SELECT embedding FROM query_embedding)) as similarity
FROM "{embeddings_table}" m, query_embedding
ORDER BY similarity DESC
LIMIT %s
""", (query_crc32, int(top_k * 2)))
results = cur.fetchall()
logging.info(f"Найдено {len(results)} предварительных результатов поиска.")
except Exception as e:
logging.error(f"Ошибка при выполнении поискового запроса: {e}")
results = []
finally:
conn.close()
# Переранжируем результаты с помощью API
reranked_results = rerank_with_api(query, results, top_k)
conn = get_db_connection()
movie_ids = [movie_id for movie_id, _ in reranked_results]
movie_data_dict = get_movie_data_from_db(conn, movie_ids)
conn.close()
output = ""
for movie_id, score in reranked_results:
# Находим данные фильма
movie_data, _ = movie_data_dict.get(movie_id, (None, None))
if movie_data:
output += f"<h3>{movie_data['name']} ({movie_data['year']})</h3>\n"
output += f"<p><strong>Жанры:</strong> {', '.join([genre['name'] for genre in movie_data['genres']])}</p>\n"
output += f"<p><strong>Описание:</strong> {movie_data.get('description', '')}</p>\n"
output += f"<p><strong>Релевантность (reranker score):</strong> {score:.4f}</p>\n"
output += "<hr>\n"
else:
logging.warning(f"Данные для фильма с ID {movie_id} не найдены в БД.")
search_time = time.time() - start_time
logging.info(f"Поиск выполнен за {search_time:.2f} секунд.")
return f"<p>Время поиска: {search_time:.2f} сек</p>{output}"
except Exception as e:
logging.error(f"Ошибка при выполнении поиска: {e}")
return "<p>Произошла ошибка при выполнении поиска.</p>"
finally:
search_in_progress = False
# Запускаем обработку фильмов в отдельном потоке
processing_thread = threading.Thread(target=process_movies)
processing_thread.start()
# Создаем интерфейс Gradio
iface = gr.Interface(
fn=search_movies,
inputs=gr.Textbox(lines=2, placeholder="Введите запрос для поиска фильмов..."),
outputs=gr.HTML(label="Результаты поиска"),
title="Семантический поиск фильмов",
description="Введите описание фильма, который вы ищете, и система найдет наиболее похожие фильмы."
)
# Запускаем интерфейс
iface.launch()