|
import gradio as gr |
|
from sentence_transformers import SentenceTransformer, util |
|
import os |
|
import time |
|
import threading |
|
import queue |
|
import torch |
|
import psycopg2 |
|
import zlib |
|
from urllib.parse import urlparse |
|
|
|
|
|
DATABASE_URL = os.environ.get("DB_URL") |
|
if DATABASE_URL is None: |
|
raise ValueError("DATABASE_URL environment variable not set.") |
|
|
|
parsed_url = urlparse(DATABASE_URL) |
|
db_params = { |
|
"host": parsed_url.hostname, |
|
"port": parsed_url.port, |
|
"database": parsed_url.path.lstrip("/"), |
|
"user": parsed_url.username, |
|
"password": parsed_url.password, |
|
"sslmode": "require" |
|
} |
|
|
|
|
|
model_name = "BAAI/bge-m3" |
|
model = SentenceTransformer(model_name) |
|
|
|
|
|
embeddings_table = "movie_embeddings" |
|
query_cache_table = "query_cache" |
|
|
|
|
|
MAX_CACHE_SIZE = 50 * 1024 * 1024 |
|
|
|
|
|
try: |
|
import json |
|
with open("movies.json", "r", encoding="utf-8") as f: |
|
movies_data = json.load(f) |
|
except FileNotFoundError: |
|
print("Ошибка: Файл movies.json не найден.") |
|
movies_data = [] |
|
|
|
|
|
movies_queue = queue.Queue() |
|
for movie in movies_data: |
|
movies_queue.put(movie) |
|
|
|
|
|
processing_complete = False |
|
|
|
search_in_progress = False |
|
|
|
|
|
db_lock = threading.Lock() |
|
|
|
|
|
batch_size = 32 |
|
|
|
def get_db_connection(): |
|
"""Устанавливает соединение с базой данных.""" |
|
try: |
|
conn = psycopg2.connect(**db_params) |
|
return conn |
|
except Exception as e: |
|
print(f"Ошибка подключения к базе данных: {e}") |
|
return None |
|
|
|
def create_embeddings_table(): |
|
"""Создает таблицу для хранения эмбеддингов фильмов, если она не существует.""" |
|
conn = get_db_connection() |
|
if conn is None: |
|
return |
|
|
|
with conn.cursor() as cur: |
|
cur.execute(f""" |
|
CREATE TABLE IF NOT EXISTS {embeddings_table} ( |
|
movie_id INTEGER, |
|
embedding_crc32 BIGINT PRIMARY KEY, |
|
string_crc32 BIGINT, |
|
model_name TEXT, |
|
embedding vector(1024) |
|
); |
|
""") |
|
conn.commit() |
|
conn.close() |
|
|
|
def create_query_cache_table(): |
|
"""Создает таблицу для кэширования эмбеддингов запросов, если она не существует.""" |
|
conn = get_db_connection() |
|
if conn is None: |
|
return |
|
|
|
with conn.cursor() as cur: |
|
cur.execute(f""" |
|
CREATE TABLE IF NOT EXISTS {query_cache_table} ( |
|
query_crc32 BIGINT PRIMARY KEY, |
|
query TEXT, |
|
model_name TEXT, |
|
embedding vector(1024), |
|
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP |
|
); |
|
CREATE INDEX IF NOT EXISTS idx_query_crc32 ON {query_cache_table} (query_crc32); |
|
CREATE INDEX IF NOT EXISTS idx_created_at ON {query_cache_table} (created_at); |
|
""") |
|
conn.commit() |
|
conn.close() |
|
|
|
def create_trigger_function(): |
|
"""Создает функцию и триггер для автоматического удаления старых записей из таблицы кэша запросов""" |
|
conn = get_db_connection() |
|
if conn: |
|
with conn.cursor() as cur: |
|
cur.execute(f""" |
|
CREATE OR REPLACE FUNCTION manage_query_cache_size() |
|
RETURNS TRIGGER AS $$ |
|
DECLARE |
|
table_size BIGINT; |
|
row_to_delete RECORD; |
|
BEGIN |
|
SELECT pg_total_relation_size('{query_cache_table}') INTO table_size; |
|
IF table_size > {MAX_CACHE_SIZE} THEN |
|
FOR row_to_delete IN |
|
SELECT query_crc32 |
|
FROM {query_cache_table} |
|
ORDER BY created_at ASC |
|
LOOP |
|
DELETE FROM {query_cache_table} WHERE query_crc32 = row_to_delete.query_crc32; |
|
SELECT pg_total_relation_size('{query_cache_table}') INTO table_size; |
|
EXIT WHEN table_size <= {MAX_CACHE_SIZE}; |
|
END LOOP; |
|
END IF; |
|
RETURN NEW; |
|
END; |
|
$$ LANGUAGE plpgsql; |
|
|
|
CREATE OR REPLACE TRIGGER trg_manage_query_cache_size |
|
AFTER INSERT ON {query_cache_table} |
|
FOR EACH ROW |
|
EXECUTE PROCEDURE manage_query_cache_size(); |
|
""") |
|
conn.commit() |
|
conn.close() |
|
|
|
|
|
create_embeddings_table() |
|
create_query_cache_table() |
|
create_trigger_function() |
|
|
|
def calculate_crc32(text): |
|
"""Вычисляет CRC32 для строки.""" |
|
return zlib.crc32(text.encode('utf-8')) & 0xFFFFFFFF |
|
|
|
def encode_string(text): |
|
"""Кодирует строку в эмбеддинг.""" |
|
return model.encode(text, convert_to_tensor=True, normalize_embeddings=True) |
|
|
|
def insert_embedding(conn, movie_id, embedding_string, model_name, embedding): |
|
"""Вставляет эмбеддинг фильма в базу данных.""" |
|
embedding_crc32 = calculate_crc32(str(embedding.tolist())) |
|
string_crc32 = calculate_crc32(embedding_string) |
|
with conn.cursor() as cur: |
|
try: |
|
cur.execute( |
|
f""" |
|
INSERT INTO {embeddings_table} (movie_id, embedding_crc32, string_crc32, model_name, embedding) |
|
VALUES (%s, %s, %s, %s, %s) |
|
ON CONFLICT (embedding_crc32) DO NOTHING; |
|
""", |
|
(movie_id, embedding_crc32, string_crc32, model_name, embedding.tolist()) |
|
) |
|
conn.commit() |
|
return True |
|
except Exception as e: |
|
print(f"Ошибка при вставке эмбеддинга фильма: {e}") |
|
conn.rollback() |
|
return False |
|
|
|
def insert_query_embedding(conn, query, model_name, embedding): |
|
"""Вставляет эмбеддинг запроса в таблицу кэша.""" |
|
query_crc32 = calculate_crc32(query) |
|
with conn.cursor() as cur: |
|
try: |
|
cur.execute( |
|
f""" |
|
INSERT INTO {query_cache_table} (query_crc32, query, model_name, embedding) |
|
VALUES (%s, %s, %s, %s) |
|
ON CONFLICT (query_crc32) DO UPDATE SET created_at = DEFAULT; |
|
""", |
|
(query_crc32, query, model_name, embedding.tolist()) |
|
) |
|
conn.commit() |
|
print(f"Эмбеддинг для запроса '{query}' сохранен в кэше.") |
|
return True |
|
except Exception as e: |
|
print(f"Ошибка при вставке эмбеддинга запроса: {e}") |
|
conn.rollback() |
|
return False |
|
|
|
def get_movie_embeddings(conn): |
|
"""Загружает все эмбеддинги фильмов из базы данных.""" |
|
movie_embeddings = {} |
|
with conn.cursor() as cur: |
|
cur.execute(f"SELECT movie_id, embedding FROM {embeddings_table}") |
|
rows = cur.fetchall() |
|
for row in rows: |
|
movie_id, embedding = row |
|
|
|
for movie in movies_data: |
|
if movie['id'] == movie_id: |
|
title = movie["name"] |
|
movie_embeddings[title] = torch.tensor(embedding) |
|
break |
|
return movie_embeddings |
|
|
|
def process_movies(): |
|
""" |
|
Обрабатывает фильмы из очереди, создавая для них эмбеддинги и сохраняя их в базу данных. |
|
""" |
|
global processing_complete |
|
conn = get_db_connection() |
|
if conn is None: |
|
processing_complete = True |
|
return |
|
|
|
while True: |
|
if search_in_progress: |
|
time.sleep(1) |
|
continue |
|
|
|
batch = [] |
|
while not movies_queue.empty() and len(batch) < batch_size: |
|
try: |
|
movie = movies_queue.get(timeout=1) |
|
batch.append(movie) |
|
except queue.Empty: |
|
break |
|
|
|
if not batch: |
|
print("Очередь фильмов пуста.") |
|
processing_complete = True |
|
break |
|
|
|
titles = [movie["name"] for movie in batch] |
|
embedding_strings = [ |
|
f"Название: {movie['name']}\nГод: {movie['year']}\nЖанры: {movie['genresList']}\nОписание: {movie['description']}" |
|
for movie in batch |
|
] |
|
|
|
print(f"Создаются эмбеддинги для фильмов: {', '.join(titles)}...") |
|
embeddings = model.encode(embedding_strings, convert_to_tensor=True, batch_size=batch_size, normalize_embeddings=True) |
|
|
|
with db_lock: |
|
for movie, embedding, embedding_string in zip(batch, embeddings, embedding_strings): |
|
if insert_embedding(conn, movie['id'], embedding_string, model_name, embedding): |
|
print(f"Эмбеддинг для фильма '{movie['name']}' сохранен в базе данных.") |
|
else: |
|
print(f"Ошибка сохранения эмбеддинга для фильма '{movie['name']}'.") |
|
|
|
conn.close() |
|
print("Обработка фильмов завершена.") |
|
|
|
def get_query_embedding_from_db(conn, query): |
|
""" |
|
Пытается получить эмбеддинг запроса из базы данных по CRC32. |
|
Возвращает эмбеддинг, если найден, иначе None. |
|
""" |
|
query_crc32 = calculate_crc32(query) |
|
with conn.cursor() as cur: |
|
cur.execute(f"SELECT embedding FROM {query_cache_table} WHERE query_crc32 = %s AND model_name = %s", (query_crc32, model_name)) |
|
result = cur.fetchone() |
|
if result: |
|
print(f"Эмбеддинг для запроса '{query}' найден в кэше.") |
|
return torch.tensor(result[0]) |
|
else: |
|
return None |
|
|
|
def search_movies(query, top_k=10): |
|
""" |
|
Ищет наиболее похожие фильмы по запросу. |
|
|
|
Args: |
|
query: Текстовый запрос. |
|
top_k: Количество возвращаемых результатов. |
|
|
|
Returns: |
|
Строку с результатами поиска в формате HTML. |
|
""" |
|
global search_in_progress |
|
search_in_progress = True |
|
start_time = time.time() |
|
print(f"\n\033[1mПоиск по запросу: '{query}'\033[0m") |
|
|
|
conn = get_db_connection() |
|
if conn is None: |
|
search_in_progress = False |
|
return "<p>Ошибка подключения к базе данных.</p>" |
|
|
|
print(f"Начало создания эмбеддинга для запроса: {time.strftime('%Y-%m-%d %H:%M:%S')}") |
|
query_embedding_tensor = get_query_embedding_from_db(conn, query) |
|
|
|
if query_embedding_tensor is None: |
|
query_embedding_tensor = encode_string(query) |
|
|
|
insert_query_embedding(conn, query, model_name, query_embedding_tensor) |
|
print(f"Окончание создания эмбеддинга для запроса: {time.strftime('%Y-%m-%d %H:%M:%S')}") |
|
|
|
with db_lock: |
|
current_movie_embeddings = get_movie_embeddings(conn) |
|
|
|
conn.close() |
|
|
|
if not current_movie_embeddings: |
|
search_in_progress = False |
|
return "<p>Пока что нет обработанных фильмов. Попробуйте позже.</p>" |
|
|
|
|
|
movie_titles = list(current_movie_embeddings.keys()) |
|
movie_embeddings_tensor = torch.stack(list(current_movie_embeddings.values())) |
|
|
|
print(f"Начало поиска похожих фильмов: {time.strftime('%Y-%m-%d %H:%M:%S')}") |
|
|
|
hits = util.semantic_search(query_embedding_tensor, movie_embeddings_tensor, top_k=top_k)[0] |
|
print(f"Окончание поиска похожих фильмов: {time.strftime('%Y-%m-%d %H:%M:%S')}") |
|
|
|
results_html = "" |
|
for hit in hits: |
|
title = movie_titles[hit['corpus_id']] |
|
score = hit['score'] |
|
|
|
for movie in movies_data: |
|
if movie["name"] == title: |
|
description = movie["description"] |
|
year = movie["year"] |
|
genres = movie["genresList"] |
|
break |
|
|
|
results_html += f"<h3><b>{title} ({year})</b></h3>" |
|
results_html += f"<p><b>Жанры:</b> {genres}</p>" |
|
results_html += f"<p><b>Описание:</b> {description}</p>" |
|
results_html += f"<p><b>Сходство:</b> {score:.4f}</p>" |
|
results_html += "<hr>" |
|
|
|
end_time = time.time() |
|
execution_time = end_time - start_time |
|
print(f"Поиск завершен за {execution_time:.4f} секунд.") |
|
search_in_progress = False |
|
return results_html |
|
|
|
|
|
processing_thread = threading.Thread(target=process_movies) |
|
|
|
|
|
iface = gr.Interface( |
|
fn=search_movies, |
|
inputs=gr.Textbox(label="Введите запрос:"), |
|
outputs=gr.HTML(label="Результаты поиска:"), |
|
title="Поиск фильмов по описанию", |
|
description="Введите запрос, и система найдет наиболее похожие фильмы по их описаниям.", |
|
examples=[ |
|
["Фильм про ограбление"], |
|
["Комедия 2019 года"], |
|
["Фантастика про космос"], |
|
], |
|
) |
|
|
|
|
|
processing_thread.start() |
|
|
|
|
|
iface.queue() |
|
iface.launch() |