opex792's picture
Update app.py
37cf39e verified
raw
history blame
15.6 kB
import gradio as gr
from sentence_transformers import SentenceTransformer, util
import os
import time
import threading
import queue
import torch
import psycopg2
import zlib
import numpy as np
from urllib.parse import urlparse
import logging
from sklearn.preprocessing import normalize
# Настройка логирования
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Настройки базы данных PostgreSQL
DATABASE_URL = os.environ.get("DATABASE_URL")
if DATABASE_URL is None:
raise ValueError("DATABASE_URL environment variable not set.")
parsed_url = urlparse(DATABASE_URL)
db_params = {
"host": parsed_url.hostname,
"port": parsed_url.port,
"database": parsed_url.path.lstrip("/"),
"user": parsed_url.username,
"password": parsed_url.password,
"sslmode": "require"
}
# Загружаем модель
model_name = "BAAI/bge-m3"
logging.info(f"Загрузка модели {model_name}...")
model = SentenceTransformer(model_name)
logging.info("Модель загружена успешно.")
# Имена таблиц
embeddings_table = "movie_embeddings"
query_cache_table = "query_cache"
# Максимальный размер таблицы кэша запросов в байтах (50MB)
MAX_CACHE_SIZE = 50 * 1024 * 1024
# Загружаем данные из файла movies.json
try:
import json
with open("movies.json", "r", encoding="utf-8") as f:
movies_data = json.load(f)
logging.info(f"Загружено {len(movies_data)} фильмов из movies.json")
except FileNotFoundError:
logging.error("Ошибка: Файл movies.json не найден.")
movies_data = []
# Очередь для необработанных фильмов
movies_queue = queue.Queue()
# Флаг, указывающий, что обработка фильмов завершена
processing_complete = False
# Флаг, указывающий, что выполняется поиск
search_in_progress = False
# Блокировка для доступа к базе данных
db_lock = threading.Lock()
# Размер пакета для обработки эмбеддингов
batch_size = 32
def get_db_connection():
"""Устанавливает соединение с базой данных."""
try:
conn = psycopg2.connect(**db_params)
return conn
except Exception as e:
logging.error(f"Ошибка подключения к базе данных: {e}")
return None
def setup_database():
"""Настраивает базу данных: создает расширение, таблицы и индексы."""
conn = get_db_connection()
if conn is None:
return
try:
with conn.cursor() as cur:
# Создаем расширение pgvector если его нет
cur.execute("CREATE EXTENSION IF NOT EXISTS vector;")
# Удаляем существующие таблицы если они есть
cur.execute(f"DROP TABLE IF EXISTS {embeddings_table}, {query_cache_table};")
# Создаем таблицу для хранения эмбеддингов фильмов
cur.execute(f"""
CREATE TABLE {embeddings_table} (
movie_id INTEGER PRIMARY KEY,
embedding_crc32 BIGINT,
string_crc32 BIGINT,
model_name TEXT,
embedding vector(1024)
);
CREATE INDEX ON {embeddings_table} (string_crc32);
""")
# Создаем таблицу для кэширования запросов
cur.execute(f"""
CREATE TABLE {query_cache_table} (
query_crc32 BIGINT PRIMARY KEY,
query TEXT,
model_name TEXT,
embedding vector(1024),
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX ON {query_cache_table} (query_crc32);
CREATE INDEX ON {query_cache_table} (created_at);
""")
conn.commit()
logging.info("База данных успешно настроена.")
except Exception as e:
logging.error(f"Ошибка при настройке базы данных: {e}")
conn.rollback()
finally:
conn.close()
# Настраиваем базу данных при запуске
setup_database()
def calculate_crc32(text):
"""Вычисляет CRC32 для строки."""
return zlib.crc32(text.encode('utf-8')) & 0xFFFFFFFF
def encode_string(text):
"""Кодирует строку в эмбеддинг."""
embedding = model.encode(text, convert_to_tensor=True, normalize_embeddings=True)
return embedding.cpu().numpy()
def get_movies_without_embeddings():
"""Получает список фильмов, для которых нужно создать эмбеддинги."""
conn = get_db_connection()
if conn is None:
return []
movies_to_process = []
try:
with conn.cursor() as cur:
# Получаем список ID фильмов, которые уже есть в базе
cur.execute(f"SELECT movie_id FROM {embeddings_table}")
existing_ids = {row[0] for row in cur.fetchall()}
# Фильтруем только те фильмы, которых нет в базе
for movie in movies_data:
if movie['id'] not in existing_ids:
movies_to_process.append(movie)
logging.info(f"Найдено {len(movies_to_process)} фильмов для обработки.")
except Exception as e:
logging.error(f"Ошибка при получении списка фильмов для обработки: {e}")
finally:
conn.close()
return movies_to_process
def get_embedding_from_db(conn, table_name, crc32_column, crc32_value, model_name):
"""Получает эмбеддинг из базы данных."""
try:
with conn.cursor() as cur:
cur.execute(f"SELECT embedding FROM {table_name} WHERE {crc32_column} = %s AND model_name = %s",
(crc32_value, model_name))
result = cur.fetchone()
if result and result[0]:
# Нормализуем эмбеддинг после извлечения из БД
return normalize(np.array(result[0]).reshape(1, -1))[0]
except Exception as e:
logging.error(f"Ошибка при получении эмбеддинга из БД: {e}")
return None
def insert_embedding(conn, table_name, movie_id, embedding_crc32, string_crc32, embedding):
"""Вставляет эмбеддинг в базу данных."""
try:
# Нормализуем эмбеддинг перед сохранением
normalized_embedding = normalize(embedding.reshape(1, -1))[0]
with conn.cursor() as cur:
cur.execute(f"""
INSERT INTO {table_name}
(movie_id, embedding_crc32, string_crc32, model_name, embedding)
VALUES (%s, %s, %s, %s, %s)
ON CONFLICT (movie_id) DO NOTHING
""", (movie_id, embedding_crc32, string_crc32, model_name, normalized_embedding.tolist()))
conn.commit()
return True
except Exception as e:
logging.error(f"Ошибка при вставке эмбеддинга: {e}")
conn.rollback()
return False
def process_movies():
"""Обрабатывает фильмы, создавая для них эмбеддинги."""
global processing_complete
logging.info("Начало обработки фильмов.")
# Получаем список фильмов, которые нужно обработать
movies_to_process = get_movies_without_embeddings()
if not movies_to_process:
logging.info("Все фильмы уже обработаны.")
processing_complete = True
return
# Добавляем фильмы в очередь
for movie in movies_to_process:
movies_queue.put(movie)
conn = get_db_connection()
if conn is None:
processing_complete = True
return
try:
while not movies_queue.empty():
if search_in_progress:
time.sleep(1)
continue
batch = []
while not movies_queue.empty() and len(batch) < batch_size:
try:
movie = movies_queue.get_nowait()
batch.append(movie)
except queue.Empty:
break
if not batch:
break
logging.info(f"Обработка пакета из {len(batch)} фильмов...")
for movie in batch:
embedding_string = f"Название: {movie['name']}\nГод: {movie['year']}\nЖанры: {movie['genresList']}\nОписание: {movie['description']}"
string_crc32 = calculate_crc32(embedding_string)
# Проверяем существующий эмбеддинг
existing_embedding = get_embedding_from_db(conn, embeddings_table, "string_crc32", string_crc32, model_name)
if existing_embedding is None:
embedding = encode_string(embedding_string)
embedding_crc32 = calculate_crc32(str(embedding.tolist()))
if insert_embedding(conn, embeddings_table, movie['id'], embedding_crc32, string_crc32, embedding):
logging.info(f"Сохранен эмбеддинг для '{movie['name']}'")
else:
logging.error(f"Ошибка сохранения эмбеддинга для '{movie['name']}'")
else:
logging.info(f"Эмбеддинг для '{movie['name']}' уже существует")
except Exception as e:
logging.error(f"Ошибка при обработке фильмов: {e}")
finally:
conn.close()
processing_complete = True
logging.info("Обработка фильмов завершена")
def get_movie_embeddings(conn):
"""Загружает все эмбеддинги фильмов из базы данных."""
movie_embeddings = {}
try:
with conn.cursor() as cur:
cur.execute(f"SELECT movie_id, embedding FROM {embeddings_table}")
for movie_id, embedding in cur.fetchall():
# Находим название фильма по ID
for movie in movies_data:
if movie['id'] == movie_id:
movie_embeddings[movie['name']] = normalize(np.array(embedding).reshape(1, -1))[0]
break
logging.info(f"Загружено {len(movie_embeddings)} эмбеддингов фильмов.")
except Exception as e:
logging.error(f"Ошибка при загрузке эмбеддингов фильмов: {e}")
return movie_embeddings
def search_movies(query, top_k=10):
"""Выполняет поиск фильмов по запросу."""
global search_in_progress
search_in_progress = True
start_time = time.time()
try:
conn = get_db_connection()
if conn is None:
return "<p>Ошибка подключения к базе данных</p>"
query_crc32 = calculate_crc32(query)
query_embedding = get_embedding_from_db(conn, query_cache_table, "query_crc32", query_crc32, model_name)
if query_embedding is None:
query_embedding = encode_string(query)
try:
with conn.cursor() as cur:
cur.execute(f"""
INSERT INTO {query_cache_table} (query_crc32, query, model_name, embedding)
VALUES (%s, %s, %s, %s)
ON CONFLICT (query_crc32) DO NOTHING
""", (query_crc32, query, model_name, query_embedding.tolist()))
conn.commit()
logging.info(f"Сохранен новый эмбеддинг запроса: {query}")
except Exception as e:
logging.error(f"Ошибка при сохранении эмбеддинга запроса: {e}")
conn.rollback()
# Используем косинусное расстояние для поиска
try:
with conn.cursor() as cur:
cur.execute(f"""
WITH query_embedding AS (
SELECT embedding
FROM {query_cache_table}
WHERE query_crc32 = %s
)
SELECT m.movie_id, 1 - (m.embedding <=> (SELECT embedding FROM query_embedding)) as similarity
FROM {embeddings_table} m, query_embedding
ORDER BY similarity DESC
LIMIT %s
""", (query_crc32, top_k))
results = cur.fetchall()
logging.info(f"Найдено {len(results)} результатов поиска.")
except Exception as e:
logging.error(f"Ошибка при выполнении поискового запроса: {e}")
results = []
output = ""
for movie_id, similarity in results:
# Находим фильм по ID
movie = next((m for m in movies_data if m['id'] == movie_id), None)
if movie:
output += f"<h3>{movie['name']} ({movie['year']})</h3>\n"
output += f"<p><strong>Жанры:</strong> {movie['genresList']}</p>\n"
output += f"<p><strong>Описание:</strong> {movie['description']}</p>\n"
output += f"<p><strong>Релевантность:</strong> {similarity:.4f}</p>\n"
output += "<hr>\n"
search_time = time.time() - start_time
logging.info(f"Поиск выполнен за {search_time:.2f} секунд.")
return f"<p>Время поиска: {search_time:.2f} сек</p>{output}"
except Exception as e:
logging.error(f"Ошибка при выполнении поиска: {e}")
return "<p>Произошла ошибка при выполнении поиска.</p>"
finally:
if conn:
conn.close()
search_in_progress = False
# Запускаем обработку фильмов в отдельном потоке
processing_thread = threading.Thread(target=process_movies)
processing_thread.start()
# Создаем интерфейс Gradio
iface = gr.Interface(
fn=search_movies,
inputs=gr.Textbox(lines=2, placeholder="Введите запрос для поиска фильмов..."),
outputs=gr.HTML(label="Результаты поиска"),
title="Семантический поиск фильмов",
description="Введите описание фильма, который вы ищете, и система найдет наиболее похожие фильмы."
)
# Запускаем интерфейс
iface.launch()