ESM-2 QLoRA
These are the checkpoints for the first-ever QLoRA for ESM-2! You can load and use them in the same way as the LoRA models.
They are based on the smallest ESM-2 model, esm2_t6_8M_UR50D, so the metrics aren't great.
Scaling to larger models for better metrics is in progress. These checkpoints were trained using the 600K dataset. To replicate the QLoRA training for ESM-2 models, you can use the conda-environment.yml file. However, as of 28/09/2023, and for the following week or two, you will need to uninstall transformers and install it from the GitHub repository instead:
pip install --upgrade git+https://github.com/huggingface/transformers.git
Once the transformers library releases an update (expected in a couple of weeks), you should be able to simply use the latest version. Gradient checkpointing will then be fully enabled, and QLoRA compatibility should be fully integrated into ESM-2 models.
Data Curation and Preprocessing
To create your own datasets and reproduce the data preprocessing used for this project, download a TSV file from UniProt with the columns Protein families, Binding sites, Active sites, and Protein sequence. Then use this notebook to do the following:
- Separate out the test sequences by holding out randomly chosen families. Each chosen family contributes all of its sequences, so no family overlaps with the training data.
- Filter out proteins with incomplete annotations.
- Merge the binding and active sites.
- Convert them to binary labels (0 for non-binding sites, 1 for binding sites).
- Split the sequences into non-overlapping chunks of 1000 residues or less to accommodate the 1022-residue context window of ESM-2 models.
The notebook also lets you reduce the size of your dataset at the end.

Note that this downsampling step is not currently ideal. It keeps proteins chosen uniformly at random from the train and test datasets and does not take into account that small families are less likely to be represented in the kept subset. Because of this shortcoming, smaller models trained on the smaller, downsampled datasets are likely biased towards larger families. A sampling approach biased towards smaller families might work better.
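The preprocessing steps described above can be sketched in plain Python. This is a hypothetical illustration, not the notebook's code. It assumes the Binding sites and Active sites fields have already been parsed into 1-based residue positions (UniProt's convention). The record layout and the function names (`site_labels`, `chunk`, `split_by_family`, `family_balanced_sample`) are made up for this example. The last function shows one possible small-family-friendly alternative to uniform downsampling: it weights each protein by the inverse of its family size and draws without replacement using Efraimidis-Spirakis keys.

```python
import random
from collections import defaultdict

MAX_CHUNK = 1000  # chunk size from the card; fits ESM-2's 1022-residue window


def site_labels(sequence, binding_sites, active_sites):
    """Merge binding and active sites into per-residue labels (1 = site, 0 = other).

    Assumes both site lists are already parsed into 1-based residue positions.
    """
    labels = [0] * len(sequence)
    for pos in set(binding_sites) | set(active_sites):
        if 1 <= pos <= len(sequence):
            labels[pos - 1] = 1
    return labels


def chunk(sequence, labels, size=MAX_CHUNK):
    """Split a sequence and its labels into non-overlapping chunks of <= size residues."""
    return [(sequence[i:i + size], labels[i:i + size]) for i in range(0, len(sequence), size)]


def split_by_family(records, test_fraction, seed=0):
    """Hold out randomly chosen whole families so no family is in both splits.

    `records` is a list of dicts with a "family" key (hypothetical layout).
    """
    families = sorted({r["family"] for r in records})
    rng = random.Random(seed)
    n_test = max(1, int(len(families) * test_fraction))
    test_families = set(rng.sample(families, n_test))
    train = [r for r in records if r["family"] not in test_families]
    test = [r for r in records if r["family"] in test_families]
    return train, test


def family_balanced_sample(records, k, seed=0):
    """Downsample to k records, weighting each record by 1 / (its family's size).

    Every family gets the same total weight, so small families are favoured
    relative to uniform sampling. Weighted sampling without replacement uses
    Efraimidis-Spirakis keys u ** (1 / w), which here is u ** family_size.
    """
    family_sizes = defaultdict(int)
    for r in records:
        family_sizes[r["family"]] += 1
    rng = random.Random(seed)
    keys = [rng.random() ** family_sizes[r["family"]] for r in records]
    order = sorted(range(len(records)), key=lambda i: keys[i], reverse=True)
    return [records[i] for i in order[:k]]
```

For example, `site_labels("MKTAYIAK", [2, 5], [5, 8])` marks residues 2, 5 and 8 as sites. `chunk` turns a 2500-residue protein into chunks of 1000, 1000 and 500 residues.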
QLoRA Info
Note that only 0.58% of the parameters are trained, with LoRA adapters applied only to the query, key, and value weight matrices.
trainable params: 23682 || all params: 4075265 || trainable%: 0.5811155838945443
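The printed counts can be checked by hand. The card does not state the rank or head, so the sketch below assumes a LoRA rank of r = 2 on the query, key and value projections of all 6 layers of esm2_t6_8M_UR50D (hidden size 320). It also assumes a trainable linear token-classification head mapping 320 features to 2 labels. Under those assumptions the trainable count comes out to exactly 23,682, and dividing by the printed total reproduces the 0.58% figure. Treat the rank and head as an inferred, hypothetical configuration rather than a confirmed one.

```python
# Hypothetical reconstruction of the trainable-parameter count.
# The LoRA rank and the classifier head are assumptions, not taken from the card.
HIDDEN = 320       # esm2_t6_8M_UR50D hidden size
NUM_LAYERS = 6     # esm2_t6_8M_UR50D transformer layers
NUM_LABELS = 2     # binding site vs. non-binding site


def lora_params(in_features, out_features, rank):
    """A LoRA adapter adds A (rank x in) and B (out x rank) matrices."""
    return rank * (in_features + out_features)


def trainable_params(rank):
    # query, key and value adapters in every layer
    qkv = 3 * lora_params(HIDDEN, HIDDEN, rank) * NUM_LAYERS
    # linear classifier: weight matrix plus one bias per label
    head = HIDDEN * NUM_LABELS + NUM_LABELS
    return qkv + head


trainable = trainable_params(rank=2)
all_params = 4075265  # total as printed above
print(trainable)                                 # 23682
print(f"{100 * trainable / all_params:.2f}%")    # 0.58%
```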
The QLoRA paper showed that, to match or exceed the performance of full finetuning, the most important hyperparameter to adjust is which weight matrices the LoRA adapters are applied to, with more being better. The rank and other hyperparameters, such as the scaling factor alpha, did not seem to matter. An important next step is therefore to check whether this transfers to protein language models as well. A general pattern is emerging in which adding adapters to more of the weight matrices reduces overfitting, so more adapter layers seem to be better in that regard too.
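The sketch below estimates what "more weight matrices" costs in parameters. It counts LoRA parameters for two target sets on esm2_t6_8M_UR50D (6 layers, hidden size 320, intermediate size 1280). The first set is the query/key/value projections used here. The second is every linear layer in the transformer blocks, which adds the attention output, intermediate and output `dense` layers of the Hugging Face ESM implementation. The rank of 2 is an assumption. The dictionary keys are descriptive labels rather than exact module paths.

```python
# Approximate LoRA parameter budgets for different target-module choices.
# Dimensions follow the esm2_t6_8M_UR50D config; the rank is an assumption.
HIDDEN, INTERMEDIATE, NUM_LAYERS = 320, 1280, 6

# (in_features, out_features) of each linear layer in one transformer block
LAYER_SHAPES = {
    "query": (HIDDEN, HIDDEN),
    "key": (HIDDEN, HIDDEN),
    "value": (HIDDEN, HIDDEN),
    "attention.output.dense": (HIDDEN, HIDDEN),
    "intermediate.dense": (HIDDEN, INTERMEDIATE),
    "output.dense": (INTERMEDIATE, HIDDEN),
}


def lora_budget(targets, rank):
    """LoRA parameters added when adapting the `targets` layers in every block."""
    per_layer = sum(rank * (i + o) for name, (i, o) in LAYER_SHAPES.items() if name in targets)
    return per_layer * NUM_LAYERS


qkv = lora_budget({"query", "key", "value"}, rank=2)
all_linear = lora_budget(set(LAYER_SHAPES), rank=2)
print(qkv, all_linear)  # 23040 69120
```

At any rank, covering all six linear layers gives exactly three times the adapter parameters of query/key/value alone. In peft this broader choice should roughly correspond to `target_modules=["query", "key", "value", "dense"]`, but check that against the model's actual module names before relying on it.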
Testing for Overfitting
Checkpoint 1
Train/Test Split from 600K dataset:
Train metrics:
{'eval_loss': 0.31757092475891113,
'eval_accuracy': 0.8666164527145709,
'eval_precision': 0.12977997642311132,
'eval_recall': 0.8907064653559833,
'eval_f1': 0.2265505142278714,
'eval_auc': 0.8783913689919987,
'eval_mcc': 0.30996745466311043}
Test metrics:
{'eval_loss': 0.3398605287075043,
'eval_accuracy': 0.8557050926566265,
'eval_precision': 0.10792930844408741,
'eval_recall': 0.7726298654561553,
'eval_f1': 0.18940102955847055,
'eval_auc': 0.8150939843855006,
'eval_mcc': 0.2535956911257298}
Metrics for this checkpoint on these datasets can be found here.
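Accuracy here is around 0.86, while precision is only around 0.11 to 0.13. That pattern is consistent with binding-site residues being a small minority of all residues: even a modest false-positive rate on the many negatives then outweighs the true positives. The sketch below computes these metrics from a binary confusion matrix. The counts are hypothetical (1% positive residues), chosen only to show that pattern. The last lines check that the reported train F1 for checkpoint 1 is the harmonic mean of the reported precision and recall.

```python
import math


def binary_metrics(tp, fp, fn, tn):
    """Accuracy, precision, recall, F1 and MCC from a binary confusion matrix."""
    accuracy = (tp + tn) / (tp + fp + fn + tn)
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1 = 2 * precision * recall / (precision + recall)
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    mcc = (tp * tn - fp * fn) / denom
    return {"accuracy": accuracy, "precision": precision,
            "recall": recall, "f1": f1, "mcc": mcc}


# Hypothetical residue-level counts with 1% positives (illustrative, not from the card)
example = binary_metrics(tp=90, fp=610, fn=10, tn=9290)
print({k: round(v, 3) for k, v in example.items()})

# Reported precision and recall for checkpoint 1, train split
p, r = 0.12977997642311132, 0.8907064653559833
print(round(2 * p * r / (p + r), 4))  # 0.2266
```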
Checkpoint 4
Train metrics:
{'eval_loss': 0.24070295691490173,
'eval_accuracy': 0.9018779246397052,
'eval_precision': 0.16624103834249204,
'eval_recall': 0.8651772818812425,
'eval_f1': 0.27889357183237473,
'eval_auc': 0.8839390799308487,
'eval_mcc': 0.3536803490333407}
Test metrics:
{'eval_loss': 0.26776671409606934,
'eval_accuracy': 0.8902711124906878,
'eval_precision': 0.13008662855482372,
'eval_recall': 0.7084623832213568,
'eval_f1': 0.219811797752809,
'eval_auc': 0.8013943890942485,
'eval_mcc': 0.2721459410994918}
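One simple way to compare the two checkpoints for overfitting is the train-minus-test gap per metric. The values below are copied from the metrics above with the `eval_` prefix dropped. The helper name `generalization_gap` is hypothetical.

```python
# Train/test values copied from the metrics above (eval_ prefix dropped)
CHECKPOINTS = {
    "checkpoint 1": {
        "train": {"auc": 0.8783913689919987, "f1": 0.2265505142278714, "mcc": 0.30996745466311043},
        "test": {"auc": 0.8150939843855006, "f1": 0.18940102955847055, "mcc": 0.2535956911257298},
    },
    "checkpoint 4": {
        "train": {"auc": 0.8839390799308487, "f1": 0.27889357183237473, "mcc": 0.3536803490333407},
        "test": {"auc": 0.8013943890942485, "f1": 0.219811797752809, "mcc": 0.2721459410994918},
    },
}


def generalization_gap(train, test):
    """Train-minus-test difference per metric; larger values suggest more overfitting."""
    return {k: train[k] - test[k] for k in train}


for name, split in CHECKPOINTS.items():
    gaps = generalization_gap(split["train"], split["test"])
    print(name, {k: round(v, 3) for k, v in gaps.items()})
```

With these numbers, checkpoint 4 has higher test F1 and MCC than checkpoint 1 but a lower test AUC. Its train-test gap is also larger for AUC, F1 and MCC alike.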