import gc
import pickle

import numpy as np
from scipy import stats
from sklearn.metrics import (
    accuracy_score,
    matthews_corrcoef,
    precision_recall_fscore_support,
    roc_auc_score,
)

from accelerate import Accelerator
from datasets import Dataset
from peft import PeftModel
from transformers import (
    AutoModelForTokenClassification,
    AutoTokenizer,
    DataCollatorForTokenClassification,
    Trainer,
)
with open("/kaggle/input/550k-dataset/train_sequences_chunked_by_family.pkl", "rb") as f: |
|
train_sequences = pickle.load(f) |
|
with open("/kaggle/input/550k-dataset/test_sequences_chunked_by_family.pkl", "rb") as f: |
|
test_sequences = pickle.load(f) |
|
with open("/kaggle/input/550k-dataset/train_labels_chunked_by_family.pkl", "rb") as f: |
|
train_labels = pickle.load(f) |
|
with open("/kaggle/input/550k-dataset/test_labels_chunked_by_family.pkl", "rb") as f: |
|
test_labels = pickle.load(f) |
|
|
|
|
|
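
# Assumed pickle contents (not verifiable from this script alone): parallel
# lists, with amino-acid strings in the sequence files and matching 0/1
# binding-site labels per residue in the label files.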

tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t12_35M_UR50D")
max_sequence_length = tokenizer.model_max_length
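
# Note: ESM2 tokenizers emit one token per residue plus <cls> and <eos>, so a
# 1000-residue chunk tokenizes to 1002 ids -- the label padding width used below.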


def compute_metrics_for_batch(sequences_batch, labels_batch, models, voting="hard"):
    # Tokenize the raw strings, padding to the longest sequence in the batch.
    batch_tokenized = tokenizer(
        sequences_batch,
        padding=True,
        truncation=True,
        max_length=max_sequence_length,
        return_tensors="pt",
        is_split_into_words=False,
    )

    batch_dataset = Dataset.from_dict(dict(batch_tokenized))
    batch_dataset = batch_dataset.add_column("labels", labels_batch[:len(batch_dataset)])

    # Pad each label sequence to 1002 positions (1000 residues + <cls>/<eos>),
    # using -100 so padded positions are excluded from the metrics below.
    labels_array = np.array(
        [np.pad(label, (0, 1002 - len(label)), constant_values=-100) for label in batch_dataset["labels"]]
    )

    # One Trainer per ensemble member, sharing a token-classification collator.
    data_collator = DataCollatorForTokenClassification(tokenizer)
    trainers = [Trainer(model=model, data_collator=data_collator) for model in models]

    # Per-model logits, each of shape (batch, seq_len, num_labels).
    all_predictions = [trainer.predict(test_dataset=batch_dataset)[0] for trainer in trainers]

    if voting == "hard":
        # Majority vote over each model's argmax predictions. keepdims=True
        # (SciPy >= 1.9) pins the (1, batch, seq_len) output shape so the
        # [0][0] indexing stays valid on newer SciPy releases.
        hard_predictions = [np.argmax(predictions, axis=2) for predictions in all_predictions]
        ensemble_predictions = stats.mode(np.stack(hard_predictions), axis=0, keepdims=True)[0][0]
    elif voting == "soft":
        # Average the raw logits across models, then take the argmax.
        avg_predictions = np.mean(all_predictions, axis=0)
        ensemble_predictions = np.argmax(avg_predictions, axis=2)
    else:
        raise ValueError("Voting must be either 'hard' or 'soft'")

    # Keep only real residue positions; -100 marks padding. (The boolean
    # indexing assumes the longest sequence in the batch tokenizes to the full
    # 1002 positions, so predictions and the 1002-wide mask line up.)
    mask_2d = labels_array != -100
    true_labels_list = [label[mask_2d[idx]] for idx, label in enumerate(labels_array)]
    true_labels = np.concatenate(true_labels_list)
    flat_predictions_list = [ensemble_predictions[idx][mask_2d[idx]] for idx in range(ensemble_predictions.shape[0])]
    flat_predictions = np.concatenate(flat_predictions_list).tolist()

    # Token-level metrics over the unmasked positions. Note that AUC here is
    # computed from the hard 0/1 predictions, not from probabilities.
    accuracy = accuracy_score(true_labels, flat_predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(true_labels, flat_predictions, average="binary")
    auc = roc_auc_score(true_labels, flat_predictions)
    mcc = matthews_corrcoef(true_labels, flat_predictions)

    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1, "auc": auc, "mcc": mcc}
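
# Hypothetical smoke test (assumes the `models` list built in the setup below);
# useful for sanity-checking shapes before a full run:
#
#     compute_metrics_for_batch(test_sequences[:8], test_labels[:8], models, voting="soft")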


def evaluate_in_batches(sequences, labels, models, dataset_name, voting, batch_size=1000, print_first_n=5):
    # Ceiling division: one extra batch for any remainder.
    num_batches = len(sequences) // batch_size + int(len(sequences) % batch_size != 0)
    metrics_list = []

    for i in range(num_batches):
        start_idx = i * batch_size
        end_idx = start_idx + batch_size
        batch_metrics = compute_metrics_for_batch(sequences[start_idx:end_idx], labels[start_idx:end_idx], models, voting)

        if i < print_first_n:
            print(f"{dataset_name} - Batch {i+1}/{num_batches} metrics: {batch_metrics}")

        metrics_list.append(batch_metrics)

    # Unweighted mean of the per-batch metrics.
    avg_metrics = {key: np.mean([metrics[key] for metrics in metrics_list]) for key in metrics_list[0]}
    return avg_metrics
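
# Caveat on the batch-mean: the final batch can be smaller than batch_size yet
# counts equally in the average, so these numbers can drift slightly from true
# corpus-level metrics. Pooling every masked prediction first and scoring once
# at the end would remove that bias.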


accelerator = Accelerator()
base_model_path = "facebook/esm2_t12_35M_UR50D"
lora_model_paths = [
    "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_cp1",
    "AmelieSchreiber/esm2_t12_35M_lora_binding_sites_v2_cp1",
]
# Load a fresh base model for each adapter: wrapping one shared base-model
# instance in several PeftModels would let the adapters interfere with each
# other through the same underlying modules.
models = [
    PeftModel.from_pretrained(AutoModelForTokenClassification.from_pretrained(base_model_path), path)
    for path in lora_model_paths
]
models = [accelerator.prepare(model) for model in models]

# Evaluate both voting schemes on both splits.
test_metrics_soft = evaluate_in_batches(test_sequences, test_labels, models, "test", voting="soft")
train_metrics_soft = evaluate_in_batches(train_sequences, train_labels, models, "train", voting="soft")
test_metrics_hard = evaluate_in_batches(test_sequences, test_labels, models, "test", voting="hard")
train_metrics_hard = evaluate_in_batches(train_sequences, train_labels, models, "train", voting="hard")
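
# If GPU memory runs tight between these four passes, a minimal cleanup sketch
# (assumes a CUDA runtime; torch is already installed as a transformers
# dependency):
#
#     import torch
#     gc.collect()
#     torch.cuda.empty_cache()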
print("Test metrics (soft voting):", test_metrics_soft) |
|
print("Train metrics (soft voting):", train_metrics_soft) |
|
print("Test metrics (hard voting):", test_metrics_hard) |
|
print("Train metrics (hard voting):", train_metrics_hard) |
|
|