Model Card for Omni-DNA

Requirements

pip install datasets ai2-olmo transformers trl

Overview

Omni-DNA is a cross-modal, multi-task genomic foundation model designed to generalize across diverse genomic tasks. Unlike previous Genomic Foundation Models (GFMs), which require separate fine-tuning for each task, Omni-DNA leverages auto-regressive transformer-based training and multi-task fine-tuning, enabling a single model to perform a wide range of genomic tasks with state-of-the-art performance.

Omni-DNA models range from 20M to 1B parameters and support tasks such as sequence annotation, regulatory element classification, acetylation/methylation prediction, and DNA2Function/DNA2Image mapping.

Base Model Details

| Model | Training Tokens | Layers | Hidden Size | Attention Heads | Context Length |
|---|---|---|---|---|---|
| Omni-DNA 20M | 300B | 8 | 256 | 8 | 250 |
| Omni-DNA 60M | 300B | 8 | 512 | 8 | 250 |
| Omni-DNA 116M | 300B | 12 | 768 | 16 | 250 |
| Omni-DNA 300M | 300B | 16 | 1024 | 16 | 250 |
| Omni-DNA 700M | 300B | 16 | 1536 | 16 | 250 |
| Omni-DNA 1B | 300B | 16 | 2048 | 16 | 250 |
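The architecture columns can be cross-checked with a little arithmetic. The sketch below derives the per-head dimension (hidden size divided by attention heads) and a rough weight count for the transformer blocks using the common GPT-style rule of thumb of 12 × layers × hidden² (about 4d² for the attention projections plus 8d² for an MLP with a 4× expansion). That rule is an assumption, not Omni-DNA's exact layout: it ignores embeddings, norms and biases and assumes a 4× MLP, so it is expected to undershoot the nominal sizes, more so for the larger models.

```python
# Rows transcribed from the table above: model -> (layers, hidden size, attention heads)
CONFIGS = {
    "Omni-DNA 20M": (8, 256, 8),
    "Omni-DNA 60M": (8, 512, 8),
    "Omni-DNA 116M": (12, 768, 16),
    "Omni-DNA 300M": (16, 1024, 16),
    "Omni-DNA 700M": (16, 1536, 16),
    "Omni-DNA 1B": (16, 2048, 16),
}


def head_dim(hidden, heads):
    """Per-head width: the hidden size is split evenly across the attention heads."""
    if hidden % heads != 0:
        raise ValueError("hidden size must be divisible by the number of heads")
    return hidden // heads


def approx_block_params(layers, hidden):
    """Rule-of-thumb weight count for the transformer blocks: 12 * layers * hidden^2.

    Assumes ~4*d^2 attention weights and ~8*d^2 MLP weights (4x expansion) per layer;
    ignores embeddings, norms and biases. Not Omni-DNA's exact architecture.
    """
    return 12 * layers * hidden ** 2


for name, (layers, hidden, heads) in CONFIGS.items():
    print(f"{name}: head_dim={head_dim(hidden, heads)}, "
          f"~{approx_block_params(layers, hidden) / 1e6:.0f}M block weights (rule of thumb)")
```

For the 116M row this gives a head dimension of 48 and roughly 85M block weights; the gap to the nominal size is consistent with what the rule of thumb leaves out.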

Model Description

  • Supported by: Microsoft Research Asia
  • Model type: Auto-regressive transformer-based genomic model
  • License: MIT
  • Date cutoff: 2024
  • Contact: Research inquiries at [email protected]

Model Sources

Capabilities

Omni-DNA is trained to perform multiple genomic tasks including:

  • Regulatory Element Classification: Enhancer/promoter/splice site detection
  • Histone Modification Prediction: Acetylation and methylation state identification
  • Genomic Function Annotation: DNA-to-text mapping (DNA2Function)
  • Cross-modal Learning: DNA-to-image mapping (DNA2Image)
  • Multi-task Learning: A single model can solve multiple tasks simultaneously

Usage

As a Generative Autoregressive Model

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load tokenizer and model
model_tokenizer_path = "zehui127/Omni-DNA-116M"
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(model_tokenizer_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_tokenizer_path, trust_remote_code=True).to(device)
model.eval()

def generate(message, model=model, max_new_tokens=1):
    """Generate a continuation of `message` and return the decoded sequence."""
    # Tokenize the input
    tokenized_message = tokenizer(
        [message], return_tensors="pt", return_token_type_ids=False, add_special_tokens=True
    ).to(device)
    with torch.no_grad():
        # Generate a response (deterministic, greedy decoding)
        response = model.generate(**tokenized_message, max_new_tokens=max_new_tokens, do_sample=False)
        # Alternative: stochastic sampling with top-k and top-p filtering
        # response = model.generate(**tokenized_message, max_new_tokens=max_new_tokens, do_sample=True, top_k=300, top_p=0.95)
    # Decode the full sequence (prompt followed by the generated tokens)
    reply = tokenizer.batch_decode(response, skip_special_tokens=False)[0]
    # Remove the spaces inserted between tokens to recover a contiguous sequence
    reply = reply.replace(" ", "")
    return reply

# Example usage
message = "ATGCGTACGTAGCTAGCTAGCTAGCTAGCTA"
output = generate(message, max_new_tokens=10)
print(f"Generated output: {output}")
```
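The commented-out alternative in `generate` switches from greedy decoding to sampling with `top_k=300` and `top_p=0.95`. As a rough illustration of what those two parameters control (not the transformers implementation), the pure-Python sketch below keeps the `top_k` most likely tokens, then keeps the smallest high-probability subset of those whose renormalized mass reaches `top_p`, and samples from what remains. The function names and the toy probability list are hypothetical.

```python
import random


def top_k_top_p_filter(probs, top_k, top_p):
    """Return a renormalized {token_index: prob} dict after top-k, then top-p filtering."""
    # Rank token indices by probability, most likely first
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    # Top-k: keep at most k of the most likely tokens
    kept = ranked[:top_k]
    total = sum(probs[i] for i in kept)
    # Top-p (nucleus): keep the shortest prefix whose renormalized mass reaches top_p
    nucleus, cumulative = [], 0.0
    for i in kept:
        nucleus.append(i)
        cumulative += probs[i] / total
        if cumulative >= top_p:
            break
    mass = sum(probs[i] for i in nucleus)
    return {i: probs[i] / mass for i in nucleus}


def sample_token(probs, top_k=300, top_p=0.95, rng=random):
    """Draw one token index from the filtered, renormalized distribution."""
    filtered = top_k_top_p_filter(probs, top_k, top_p)
    tokens, weights = zip(*filtered.items())
    return rng.choices(tokens, weights=weights, k=1)[0]


# Toy next-token distribution over four tokens
probs = [0.5, 0.3, 0.15, 0.05]
print(top_k_top_p_filter(probs, top_k=3, top_p=0.8))
print(sample_token(probs, top_k=3, top_p=0.8, rng=random.Random(0)))
```

With `top_k=3, top_p=0.8`, only the two most likely tokens survive, so sampling can never pick the low-probability tail; greedy decoding (`do_sample=False`) is the limiting case that always picks index 0.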

Attaching a Classification Head

```python
import torch
import transformers
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Load the model with a classification head
model = AutoModelForSequenceClassification.from_pretrained("zehui127/Omni-DNA-116M", num_labels=2, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("zehui127/Omni-DNA-116M", trust_remote_code=True)

### Finetuning the loaded model on the target task ... ###
# Define training_args, train_data, val_data, compute_metrics and collate_fn ...
trainer = transformers.Trainer(model=model,
                               tokenizer=tokenizer,
                               args=training_args,
                               compute_metrics=compute_metrics,
                               train_dataset=train_data,
                               eval_dataset=val_data,
                               data_collator=collate_fn)
trainer.train()

# ... After finetuning: example DNA sequence
sequence = "ATGCGTACGTAGCTAGCTAGCTAGCTAGCTA"

# Tokenize the input sequence and move it to the model's device
inputs = tokenizer(sequence, return_tensors="pt", return_token_type_ids=False).to(model.device)

# Forward pass (inference only, so no gradients are needed)
model.eval()
with torch.no_grad():
    outputs = model(**inputs)

# Extract classification logits and get the predicted label
logits = outputs.logits
predicted_class = logits.argmax(dim=-1).item()

print(f"Predicted class: {predicted_class}")
```
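The `Trainer` call above expects a user-defined `compute_metrics`. One possible version for the two-label head is sketched below, reporting accuracy and the Matthews correlation coefficient (MCC), a common choice for genomic classification benchmarks. The metric choice is an assumption rather than the authors' evaluation setup; the function relies only on the `(logits, labels)` pair that `Trainer` passes in and treats class 1 as the positive class.

```python
import numpy as np


def compute_metrics(eval_pred):
    """Accuracy and Matthews correlation coefficient (MCC) for binary classification.

    `eval_pred` unpacks into (logits, labels), i.e. the (predictions, label_ids)
    pair that transformers.Trainer provides at evaluation time.
    """
    logits, labels = eval_pred
    preds = np.argmax(np.asarray(logits), axis=-1)
    labels = np.asarray(labels)
    # Confusion-matrix counts, with class 1 as the positive class
    tp = float(np.sum((preds == 1) & (labels == 1)))
    tn = float(np.sum((preds == 0) & (labels == 0)))
    fp = float(np.sum((preds == 1) & (labels == 0)))
    fn = float(np.sum((preds == 0) & (labels == 1)))
    denom = np.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    # MCC is undefined when a confusion-matrix margin is empty; report 0 in that case
    mcc = (tp * tn - fp * fn) / denom if denom > 0 else 0.0
    return {"accuracy": float(np.mean(preds == labels)), "mcc": float(mcc)}


# Toy check: four examples, one false negative
logits = np.array([[2.0, 1.0], [0.1, 0.9], [0.3, 0.2], [0.0, 1.0]])
labels = np.array([0, 1, 1, 1])
print(compute_metrics((logits, labels)))
```

MCC stays informative when the two classes are imbalanced, where plain accuracy can look good for a model that always predicts the majority class.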

Supervised Finetuning (Making Predictions Generatively)

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTTrainer, SFTConfig, DataCollatorForCompletionOnlyLM
from datasets import load_dataset, concatenate_datasets

# Load the pre-trained model and tokenizer
model_name = "zehui127/Omni-DNA-116M"
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# Load and process the dataset (assumes JSON records with "instruction", "task" and "output" fields)
dataset = load_dataset("json", data_files={"train": "path/to/train.json"})
dataset = dataset["train"]

# Group the dataset by task type (if necessary); returns {task_name: Dataset}
def group_by_task_type(dataset):
    task_types = set(dataset["task"])
    task_datasets = {}
    for task in task_types:
        task_datasets[task] = dataset.filter(lambda x: x["task"] == task)
    return task_datasets

# Formatting function for generative fine-tuning: "<instruction> <task> [SEP] <output>".
# In the TRL releases this example targets, SFTTrainer passes a batch of examples
# (each column is a list), so return one formatted string per example.
def formatting_prompts_func(examples):
    return [
        f"{instruction} {task} [SEP] {output}"
        for instruction, task, output in zip(
            examples["instruction"], examples["task"], examples["output"]
        )
    ]

# Compute the loss only on the tokens after the "[SEP]" response template
response_template = "[SEP]"
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)
```
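`DataCollatorForCompletionOnlyLM` restricts the training loss to the tokens that follow the `[SEP]` response template: label positions up to and including the template are set to -100, the default `ignore_index` of PyTorch's cross-entropy loss. The pure-Python sketch below illustrates that masking on made-up token IDs (hypothetical values, not Omni-DNA's vocabulary); it is a simplified illustration of the idea, not TRL's code.

```python
IGNORE_INDEX = -100  # default ignore_index of PyTorch's CrossEntropyLoss


def mask_prompt_tokens(input_ids, response_template_ids):
    """Return labels with every position up to and including the response template masked."""
    ids = list(input_ids)
    template = list(response_template_ids)
    n = len(template)
    end = None
    for start in range(len(ids) - n + 1):
        if ids[start:start + n] == template:
            end = start + n  # if the template repeats, mask through its last occurrence
    if end is None:
        # Template not found: mask everything so the example contributes no loss
        return [IGNORE_INDEX] * len(ids)
    return [IGNORE_INDEX] * end + ids[end:]


# Hypothetical IDs: prompt tokens 11-13, "[SEP]" = 99, response tokens 7 and 8
print(mask_prompt_tokens([11, 12, 13, 99, 7, 8], [99]))  # → [-100, -100, -100, -100, 7, 8]
```

Masking the prompt means the model is optimized only to produce the answer after `[SEP]`, rather than to reproduce the instruction and input DNA sequence.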

```python
# (continued) Fine-tuning configuration
training_args = SFTConfig(
    per_device_train_batch_size=6,
    per_device_eval_batch_size=8,
    save_total_limit=1,
    max_seq_length=512,
    output_dir="./finetuned_omni_dna",
    save_safetensors=False,
    num_train_epochs=10,
    save_strategy="epoch",
    neftune_noise_alpha=5,  # NEFTune: add noise to the input embeddings during training as regularization
)

# Trainer setup
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,  # renamed `processing_class` in newer TRL releases
    train_dataset=dataset,
    args=training_args,
    formatting_func=formatting_prompts_func,
    data_collator=collator,
)

# Train the model
trainer.train()
```
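The `neftune_noise_alpha=5` setting enables NEFTune, which perturbs the input embeddings with uniform noise during training only. Following the NEFTune paper, the noise for a sequence of length L with embedding width d is drawn from Uniform(-1, 1) and scaled by α/√(L·d), so longer sequences and wider embeddings receive smaller per-element noise. The sketch below applies that rule to a plain list-of-lists embedding matrix; `neftune_noise` is a hypothetical helper illustrating the formula, not TRL's hook-based implementation.

```python
import math
import random


def neftune_noise(embeddings, alpha, rng=random):
    """Return a copy of an L x d embedding matrix with NEFTune-style uniform noise added."""
    seq_len, hidden_dim = len(embeddings), len(embeddings[0])
    # Each element receives noise in [-alpha / sqrt(L * d), +alpha / sqrt(L * d)]
    scale = alpha / math.sqrt(seq_len * hidden_dim)
    return [[x + rng.uniform(-scale, scale) for x in row] for row in embeddings]


# Toy example: L = 25 positions, d = 4 dims, alpha = 5, so scale = 5 / sqrt(100) = 0.5
noisy = neftune_noise([[0.0] * 4 for _ in range(25)], alpha=5, rng=random.Random(0))
print(max(abs(v) for row in noisy for v in row))
```

Because the noise is only injected while training, inference with the fine-tuned model is unaffected by this setting.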
Model size: 116M params (Safetensors, tensor type F32)