opex792 commited on
Commit
9a46a7b
·
verified ·
1 Parent(s): 0015f60

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -21
app.py CHANGED
@@ -12,7 +12,7 @@ from urllib.parse import urlparse
12
  import logging
13
  from sklearn.preprocessing import normalize
14
  from concurrent.futures import ThreadPoolExecutor
15
- from transformers import AutoModelForSequenceClassification, AutoTokenizer
16
 
17
  # Настройка логирования
18
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@@ -38,13 +38,10 @@ logging.info(f"Загрузка модели {model_name}...")
38
  model = SentenceTransformer(model_name)
39
  logging.info("Модель загружена успешно.")
40
 
41
- # Загружаем модель реранкера
42
- reranker_name = 'BAAI/bge-reranker-v2-m3'
43
- logging.info(f"Загрузка модели реранкера {reranker_name}...")
44
- reranker_tokenizer = AutoTokenizer.from_pretrained(reranker_name)
45
- reranker_model = AutoModelForSequenceClassification.from_pretrained(reranker_name)
46
- reranker_model.eval()
47
- logging.info("Модель реранкера загружена успешно.")
48
 
49
  # Имена таблиц
50
  embeddings_table = "movie_embeddings"
@@ -81,6 +78,9 @@ batch_size = 32
81
  # Количество потоков для параллельной обработки
82
  num_threads = 5
83
 
 
 
 
84
  def get_db_connection():
85
  """Устанавливает соединение с базой данных."""
86
  try:
@@ -298,24 +298,84 @@ def get_movie_embeddings(conn):
298
  logging.error(f"Ошибка при загрузке эмбеддингов фильмов: {e}")
299
  return movie_embeddings
300
 
301
- def rerank_results(query, results):
302
- """Переранжирует результаты поиска с помощью реранкера."""
303
- logging.info(f"Начало переранжирования для запроса: '{query}'")
304
- pairs = []
 
 
 
 
 
305
  movie_ids = []
306
- for i, (movie_id, _) in enumerate(results):
307
  movie = next((m for m in movies_data if m['id'] == movie_id), None)
308
  if movie:
309
  movie_info = f"Название: {movie['name']}\nГод: {movie['year']}\nЖанры: {movie['genreslist']}\nОписание: {movie['description']}"
310
- pairs.append([query, movie_info])
311
  movie_ids.append(movie_id)
312
- logging.info(f"Обработка фильма для реранка {i+1}/{len(results)}: {movie['name']}")
313
 
314
- with torch.no_grad():
315
- inputs = reranker_tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
316
- scores = reranker_model(**inputs, return_dict=True).logits.view(-1, ).float()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
317
 
318
- reranked_results = sorted(zip(movie_ids, scores.tolist()), key=lambda x: x[1], reverse=True)
319
  logging.info("Переранжирование завершено.")
320
  return reranked_results
321
 
@@ -362,7 +422,7 @@ def search_movies(query, top_k=20):
362
  FROM {embeddings_table} m, query_embedding
363
  ORDER BY similarity DESC
364
  LIMIT %s
365
- """, (query_crc32, int(top_k * 2))) # Уменьшаем лимит до * 1.1
366
 
367
  results = cur.fetchall()
368
  logging.info(f"Найдено {len(results)} предварительных результатов поиска.")
@@ -381,7 +441,7 @@ def search_movies(query, top_k=20):
381
  output += f"<h3>{movie['name']} ({movie['year']})</h3>\n"
382
  output += f"<p><strong>Жанры:</strong> {movie['genreslist']}</p>\n"
383
  output += f"<p><strong>Описание:</strong> {movie['description']}</p>\n"
384
- output += f"<p><strong>Релевантность (reranker score):</strong> {score:.4f}</p>\n"
385
  output += "<hr>\n"
386
 
387
  search_time = time.time() - start_time
 
12
  import logging
13
  from sklearn.preprocessing import normalize
14
  from concurrent.futures import ThreadPoolExecutor
15
+ import requests
16
 
17
  # Настройка логирования
18
  logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
38
  model = SentenceTransformer(model_name)
39
  logging.info("Модель загружена успешно.")
40
 
41
+ # Voyage AI API Key
42
+ VOYAGE_API_KEY = os.environ.get("VOYAGE_API_KEY")
43
+ if VOYAGE_API_KEY is None:
44
+ raise ValueError("VOYAGE_API_KEY environment variable not set.")
 
 
 
45
 
46
  # Имена таблиц
47
  embeddings_table = "movie_embeddings"
 
78
  # Количество потоков для параллельной обработки
79
  num_threads = 5
80
 
81
+ # Количество потоков для параллельного реранкинга
82
+ rerank_threads = 5 # Подберите оптимальное значение
83
+
84
  def get_db_connection():
85
  """Устанавливает соединение с базой данных."""
86
  try:
 
298
  logging.error(f"Ошибка при загрузке эмбеддингов фильмов: {e}")
299
  return movie_embeddings
300
 
301
+ def rerank_batch_voyage(query, batch):
302
+ """Переранжирует пакет результатов с помощью Voyage AI."""
303
+ url = "https://api.voyageai.com/v1/rerank"
304
+ headers = {
305
+ "Authorization": f"Bearer {VOYAGE_API_KEY}",
306
+ "content-type": "application/json"
307
+ }
308
+
309
+ documents = []
310
  movie_ids = []
311
+ for movie_id, _ in batch:
312
  movie = next((m for m in movies_data if m['id'] == movie_id), None)
313
  if movie:
314
  movie_info = f"Название: {movie['name']}\nГод: {movie['year']}\nЖанры: {movie['genreslist']}\nОписание: {movie['description']}"
315
+ documents.append(movie_info)
316
  movie_ids.append(movie_id)
 
317
 
318
+ payload = {
319
+ "query": query,
320
+ "documents": documents,
321
+ "model": "rerank-2", # Можно использовать rerank-2-lite для более быстрой, но менее точной модели
322
+ "return_documents": False,
323
+ "truncation": True
324
+ }
325
+
326
+ try:
327
+ response = requests.post(url, headers=headers, json=payload)
328
+ response.raise_for_status() # Проверка на ошибки HTTP
329
+ response_json = response.json()
330
+
331
+ reranked_results = []
332
+ for item in response_json['data']:
333
+ reranked_results.append((movie_ids[item['index']], item['relevance_score']))
334
+
335
+ logging.info(f"Voyage AI: Успешно переранжирован батч. Задействовано токенов: {response_json['usage']['total_tokens']}")
336
+ return reranked_results
337
+
338
+ except requests.exceptions.RequestException as e:
339
+ logging.error(f"Ошибка запроса к Voyage AI: {e}")
340
+ return []
341
+ except KeyError as e:
342
+ logging.error(f"Ошибка обработки ответа от Voyage AI: {e}. Полный ответ: {response_json}")
343
+ return []
344
+
345
+ def rerank_results(query, results):
346
+ """Переранжирует результаты поиска с помощью Voyage AI."""
347
+ logging.info(f"Начало переранжирования для запроса: '{query}'")
348
+ reranked_results = []
349
+
350
+ with ThreadPoolExecutor(max_workers=rerank_threads) as executor:
351
+ futures = []
352
+ batch = []
353
+ batch_num = 0
354
+ for i, result in enumerate(results):
355
+ batch.append(result)
356
+ if len(batch) >= batch_size: # Отправл��ем на реранк батчами
357
+ logging.info(f"Отправка на переранжирование батча {batch_num+1} ({len(batch)} фильмов)")
358
+ future = executor.submit(rerank_batch_voyage, query, batch)
359
+ futures.append(future)
360
+ batch = []
361
+ batch_num += 1
362
+
363
+ # Обработка остатка
364
+ if batch:
365
+ logging.info(f"Отправка на переранжирование батча {batch_num+1} ({len(batch)} фильмов)")
366
+ future = executor.submit(rerank_batch_voyage, query, batch)
367
+ futures.append(future)
368
+
369
+ # Сбор результатов
370
+ for i, future in enumerate(futures):
371
+ try:
372
+ batch_result = future.result()
373
+ reranked_results.extend(batch_result)
374
+ logging.info(f"Завершен реранк батча {i+1}")
375
+ except Exception as e:
376
+ logging.error(f"Ошибка при переранжировании батча {i+1}: {e}")
377
 
378
+ reranked_results = sorted(reranked_results, key=lambda x: x[1], reverse=True)
379
  logging.info("Переранжирование завершено.")
380
  return reranked_results
381
 
 
422
  FROM {embeddings_table} m, query_embedding
423
  ORDER BY similarity DESC
424
  LIMIT %s
425
+ """, (query_crc32, int(top_k * 1.1))) # Уменьшаем лимит до * 1.1
426
 
427
  results = cur.fetchall()
428
  logging.info(f"Найдено {len(results)} предварительных результатов поиска.")
 
441
  output += f"<h3>{movie['name']} ({movie['year']})</h3>\n"
442
  output += f"<p><strong>Жанры:</strong> {movie['genreslist']}</p>\n"
443
  output += f"<p><strong>Описание:</strong> {movie['description']}</p>\n"
444
+ output += f"<p><strong>Релевантность (Voyage AI reranker score):</strong> {score:.4f}</p>\n"
445
  output += "<hr>\n"
446
 
447
  search_time = time.time() - start_time