Проект: Чат-бот с использованием модели ruT5-base для ответов на вопросы

Описание

Этот проект представляет собой систему, которая использует предобученную модель ruT5-base для генерации ответов на вопросы, основанных на предоставленном контексте. Я дообучаю модель на датасете SberQUAD, адаптируя её для задач вопросно-ответного взаимодействия на русском языке.

Датасет

Я использую датасет SberQUAD, который содержит примеры вопросов и ответов на них в контексте различных текстов. Датасет разбит на тренировочные, валидационные и тестовые части.

Архитектура модели

В качестве базовой модели используется ruT5-base — Encoder-Decoder модель, оптимизированная для задач на русском языке. Модель была дополнительно дообучена на кастомных данных для улучшения генерации ответов на основе предоставленного контекста.

Параметры обучения

Для обучения использовались следующие параметры:

output_dir="./models",
optim="adafactor",
num_train_epochs=1, # в идеале 2 эпохи
do_train=True,
gradient_checkpointing=True,
bf16=True,
per_device_train_batch_size=8,
per_device_eval_batch_size=12,
gradient_accumulation_steps=4,
logging_dir="./logs",
report_to="wandb",
logging_steps=10,
save_strategy="steps",
save_steps=5000,
evaluation_strategy="steps",
eval_steps=300,
learning_rate=3e-5,
predict_with_generate=False,
generation_max_length=64

К сожалению, мне не хватило вычислительного времени на Google Collab, поэтому модель была обучена только на одной эпохе с ~1416 шагами.

Результаты обучения

Шаг Loss на валидации Sbleu Chr F Rouge1 Rouge2 Rougel
300 1.025008 18.206400 62.316300 0.110400 0.035200 0.109800
600 1.007530 18.523100 62.564700 0.113300 0.036500 0.112800
900 0.959073 18.869000 63.001700 0.115100 0.035600 0.114600
1200 0.944776 18.656300 62.819800 0.115400 0.035800 0.115000
Downloads last month
17
Safetensors
Model size
223M params
Tensor type
F32
·
Inference Examples
This model does not have enough activity to be deployed to Inference API (serverless) yet. Increase its social visibility and check back later, or deploy to Inference Endpoints (dedicated) instead.

Model tree for RichelieuGVG/tinek_sample_model

Finetuned
(12)
this model

Dataset used to train RichelieuGVG/tinek_sample_model