|
import gradio as gr |
|
from sentence_transformers import SentenceTransformer, util |
|
import json |
|
import os |
|
import time |
|
import threading |
|
import queue |
|
import torch |
|
|
|
|
|
model_name_kalm = "HIT-TMG/KaLM-embedding-multilingual-mini-instruct-v1.5" |
|
model_kalm = SentenceTransformer(model_name_kalm) |
|
|
|
model_name_bge = "BAAI/bge-m3" |
|
model_bge = SentenceTransformer(model_name_bge) |
|
|
|
|
|
embeddings_file_kalm = f"movie_embeddings_{model_name_kalm.replace('/', '_')}.json" |
|
query_embeddings_file_kalm = f"query_embeddings_{model_name_kalm.replace('/', '_')}.json" |
|
|
|
embeddings_file_bge = f"movie_embeddings_{model_name_bge.replace('/', '_')}.json" |
|
query_embeddings_file_bge = f"query_embeddings_{model_name_bge.replace('/', '_')}.json" |
|
|
|
|
|
try: |
|
with open("movies.json", "r", encoding="utf-8") as f: |
|
movies_data = json.load(f) |
|
except FileNotFoundError: |
|
print("Ошибка: Файл movies.json не найден.") |
|
movies_data = [] |
|
|
|
|
|
if os.path.exists(embeddings_file_kalm): |
|
with open(embeddings_file_kalm, "r", encoding="utf-8") as f: |
|
movie_embeddings_kalm = json.load(f) |
|
print("Загружены эмбеддинги фильмов для KaLM из файла.") |
|
else: |
|
movie_embeddings_kalm = {} |
|
|
|
|
|
if os.path.exists(query_embeddings_file_kalm): |
|
with open(query_embeddings_file_kalm, "r", encoding="utf-8") as f: |
|
query_embeddings_kalm = json.load(f) |
|
print("Загружены эмбеддинги запросов для KaLM из файла.") |
|
else: |
|
query_embeddings_kalm = {} |
|
|
|
|
|
if os.path.exists(embeddings_file_bge): |
|
with open(embeddings_file_bge, "r", encoding="utf-8") as f: |
|
movie_embeddings_bge = json.load(f) |
|
print("Загружены эмбеддинги фильмов для BGE-M3 из файла.") |
|
else: |
|
movie_embeddings_bge = {} |
|
|
|
|
|
if os.path.exists(query_embeddings_file_bge): |
|
with open(query_embeddings_file_bge, "r", encoding="utf-8") as f: |
|
query_embeddings_bge = json.load(f) |
|
print("Загружены эмбеддинги запросов для BGE-M3 из файла.") |
|
else: |
|
query_embeddings_bge = {} |
|
|
|
|
|
movies_queue_kalm = queue.Queue() |
|
movies_queue_bge = queue.Queue() |
|
|
|
for movie in movies_data: |
|
if movie["name"] not in movie_embeddings_kalm: |
|
movies_queue_kalm.put(movie) |
|
if movie["name"] not in movie_embeddings_bge: |
|
movies_queue_bge.put(movie) |
|
|
|
|
|
processing_complete_kalm = False |
|
processing_complete_bge = False |
|
|
|
|
|
search_in_progress_kalm = False |
|
search_in_progress_bge = False |
|
|
|
|
|
movie_embeddings_lock_kalm = threading.Lock() |
|
movie_embeddings_lock_bge = threading.Lock() |
|
|
|
|
|
batch_size = 32 |
|
|
|
|
|
query_prompt_kalm = "Инструкция: Найди релевантные фильмы по запросу. \n Запрос: " |
|
|
|
def encode_string(text, model, prompt=None): |
|
"""Кодирует строку в эмбеддинг с использованием инструкции, если она задана.""" |
|
if prompt: |
|
return model.encode(text, prompt=prompt, convert_to_tensor=True, normalize_embeddings=True, batch_size=batch_size) |
|
else: |
|
return model.encode(text, convert_to_tensor=True, normalize_embeddings=True, batch_size=batch_size) |
|
|
|
def process_movies(model, embeddings_file, movie_embeddings, movies_queue, lock, model_name): |
|
""" |
|
Обрабатывает фильмы из очереди, создавая для них эмбеддинги. |
|
""" |
|
global processing_complete_kalm, processing_complete_bge |
|
|
|
while True: |
|
batch = [] |
|
while not movies_queue.empty() and len(batch) < batch_size: |
|
try: |
|
movie = movies_queue.get(timeout=1) |
|
batch.append(movie) |
|
except queue.Empty: |
|
break |
|
|
|
if not batch: |
|
print(f"Очередь фильмов для {model_name} пуста.") |
|
if model_name == model_name_kalm: |
|
processing_complete_kalm = True |
|
elif model_name == model_name_bge: |
|
processing_complete_bge = True |
|
break |
|
|
|
titles = [movie["name"] for movie in batch] |
|
embedding_strings = [ |
|
f"Название: {movie['name']}\nГод: {movie['year']}\nЖанры: {movie['genresList']}\nОписание: {movie['description']}" |
|
for movie in batch |
|
] |
|
|
|
print(f"Создаются эмбеддинги для фильмов ({model_name}): {', '.join(titles)}...") |
|
embeddings = model.encode(embedding_strings, convert_to_tensor=True, batch_size=batch_size, normalize_embeddings=True).tolist() |
|
|
|
with lock: |
|
for title, embedding in zip(titles, embeddings): |
|
movie_embeddings[title] = embedding |
|
|
|
with open(embeddings_file, "w", encoding="utf-8") as f: |
|
json.dump(movie_embeddings, f, ensure_ascii=False, indent=4) |
|
print(f"Эмбеддинги для фильмов ({model_name}): {', '.join(titles)} созданы и сохранены.") |
|
|
|
print(f"Обработка фильмов для {model_name} завершена.") |
|
|
|
def get_query_embedding(query, model, query_embeddings, query_embeddings_file, prompt=None): |
|
""" |
|
Возвращает эмбеддинг для запроса с инструкцией. |
|
Если эмбеддинг уже создан, возвращает его из словаря. |
|
Иначе создает эмбеддинг, сохраняет его и возвращает. |
|
""" |
|
if query in query_embeddings: |
|
print(f"Эмбеддинг для запроса '{query}' уже существует.") |
|
return query_embeddings[query] |
|
else: |
|
print(f"Создается эмбеддинг для запроса '{query}'...") |
|
embedding = encode_string(query, model, prompt=prompt).tolist() |
|
query_embeddings[query] = embedding |
|
|
|
with open(query_embeddings_file, "w", encoding="utf-8") as f: |
|
json.dump(query_embeddings, f, ensure_ascii=False, indent=4) |
|
print(f"Эмбеддинг для запроса '{query}' создан и сохранен.") |
|
return embedding |
|
|
|
def search_movies(query, model, movie_embeddings, movies_data, query_embeddings, query_embeddings_file, top_k=10, query_prompt=None): |
|
""" |
|
Ищет наиболее похожие фильмы по запросу с использованием инструкции. |
|
|
|
Args: |
|
query: Текстовый запрос. |
|
model: Модель для эмбеддингов. |
|
movie_embeddings: Словарь с эмбеддингами фильмов. |
|
movies_data: Данные о фильмах. |
|
top_k: Количество возвращаемых результатов. |
|
query_prompt: Инструкция для запроса (для KaLM). |
|
|
|
Returns: |
|
Строку с результатами поиска в формате HTML. |
|
""" |
|
global search_in_progress_kalm, search_in_progress_bge |
|
|
|
if model == model_kalm: |
|
search_in_progress_kalm = True |
|
elif model == model_bge: |
|
search_in_progress_bge = True |
|
|
|
start_time = time.time() |
|
print(f"\n\033[1mПоиск по запросу: '{query}'\033[0m") |
|
|
|
print(f"Начало создания эмбеддинга для запроса: {time.strftime('%Y-%m-%d %H:%M:%S')}") |
|
query_embedding_tensor = torch.tensor(get_query_embedding(query, model, query_embeddings, query_embeddings_file, prompt=query_prompt)) |
|
print(f"Окончание создания эмбеддинга для запроса: {time.strftime('%Y-%m-%d %H:%M:%S')}") |
|
|
|
if model == model_kalm: |
|
with movie_embeddings_lock_kalm: |
|
current_movie_embeddings = movie_embeddings.copy() |
|
elif model == model_bge: |
|
with movie_embeddings_lock_bge: |
|
current_movie_embeddings = movie_embeddings.copy() |
|
|
|
if not current_movie_embeddings: |
|
if model == model_kalm: |
|
search_in_progress_kalm = False |
|
elif model == model_bge: |
|
search_in_progress_bge = False |
|
return "<p>Пока что нет обработанных фильмов. Попробуйте позже.</p>" |
|
|
|
|
|
movie_titles = list(current_movie_embeddings.keys()) |
|
movie_embeddings_tensor = torch.tensor(list(current_movie_embeddings.values())) |
|
|
|
print(f"Начало поиска похожих фильмов: {time.strftime('%Y-%m-%d %H:%M:%S')}") |
|
|
|
hits = util.semantic_search(query_embedding_tensor, movie_embeddings_tensor, top_k=top_k)[0] |
|
print(f"Окончание поиска похожих фильмов: {time.strftime('%Y-%m-%d %H:%M:%S')}") |
|
|
|
results_html = "" |
|
for hit in hits: |
|
title = movie_titles[hit['corpus_id']] |
|
score = hit['score'] |
|
|
|
for movie in movies_data: |
|
if movie["name"] == title: |
|
description = movie["description"] |
|
year = movie["year"] |
|
genres = movie["genresList"] |
|
break |
|
|
|
results_html += f"<h3><b>{title} ({year})</b></h3>" |
|
results_html += f"<p><b>Жанры:</b> {genres}</p>" |
|
results_html += f"<p><b>Описание:</b> {description}</p>" |
|
results_html += f"<p><b>Сходство:</b> {score:.4f}</p>" |
|
results_html += "<hr>" |
|
|
|
end_time = time.time() |
|
execution_time = end_time - start_time |
|
print(f"Поиск завершен за {execution_time:.4f} секунд.") |
|
|
|
if model == model_kalm: |
|
search_in_progress_kalm = False |
|
elif model == model_bge: |
|
search_in_progress_bge = False |
|
|
|
return results_html |
|
|
|
|
|
processing_thread_kalm = threading.Thread(target=process_movies, args=(model_kalm, embeddings_file_kalm, movie_embeddings_kalm, movies_queue_kalm, movie_embeddings_lock_kalm, model_name_kalm)) |
|
processing_thread_bge = threading.Thread(target=process_movies, args=(model_bge, embeddings_file_bge, movie_embeddings_bge, movies_queue_bge, movie_embeddings_lock_bge, model_name_bge)) |
|
|
|
|
|
processing_thread_kalm.start() |
|
processing_thread_bge.start() |
|
|
|
def search_with_kalm(query): |
|
return search_movies(query, model_kalm, movie_embeddings_kalm, movies_data, query_embeddings_kalm, query_embeddings_file_kalm, top_k=10, query_prompt=query_prompt_kalm) |
|
|
|
def search_with_bge(query): |
|
return search_movies(query, model_bge, movie_embeddings_bge, movies_data, query_embeddings_bge, query_embeddings_file_bge, top_k=10) |
|
|
|
with gr.Blocks() as demo: |
|
with gr.Tab("KaLM"): |
|
text_input_kalm = gr.Textbox(label="Введите запрос для KaLM") |
|
text_output_kalm = gr.HTML() |
|
text_button_kalm = gr.Button("Поиск с KaLM") |
|
with gr.Tab("BGE-M3"): |
|
text_input_bge = gr.Textbox(label="Введите запрос для BGE-M3") |
|
text_output_bge = gr.HTML() |
|
text_button_bge = gr.Button("Поиск с BGE-M3") |
|
|
|
text_button_kalm.click(search_with_kalm, inputs=text_input_kalm, outputs=text_output_kalm) |
|
text_button_bge.click(search_with_bge, inputs=text_input_bge, outputs=text_output_bge) |
|
|
|
demo.queue() |
|
demo.launch() |