Transcription Factor Binding Prediction with OmniGenBench
This notebook is a step-by-step guide to extending OmniGenBench to transcription factor binding (TFB) prediction by fine-tuning the OmniGenome-52M model on the DeepSEA dataset. The task is multi-label classification: from a DNA sequence, predict the binding of many transcription factors at once.
Dataset Description: The dataset used in this notebook is derived from the DeepSEA dataset, which is designed for studying the effects of non-coding variants. It consists of DNA sequences of 1000 base pairs, each associated with 919 binary labels corresponding to various chromatin features (transcription factor binding, DNase I sensitivity, and histone marks). For this task, we use a preprocessed version available from the yangheng/tfb_prediction dataset on Hugging Face.
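The 919 labels are not mutually exclusive, so each one is treated as an independent binary decision rather than as one class out of 919. The sketch below shows the per-label sigmoid plus binary cross-entropy objective commonly used for this kind of multi-label setup. It is written in pure Python for illustration, and whether the helpers in utils.py use exactly this loss is an assumption.

```python
import math


def sigmoid(x: float) -> float:
    """Logistic function, written to avoid overflow for large |x|."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def multilabel_bce(logits: list[float], targets: list[int]) -> float:
    """Mean binary cross-entropy over independent labels.

    Each label has its own sigmoid, so several labels can be 1 at the same
    time, unlike softmax-based multi-class classification.
    """
    if len(logits) != len(targets):
        raise ValueError("logits and targets must have the same length")
    eps = 1e-12  # keeps log() away from 0
    total = 0.0
    for z, y in zip(logits, targets):
        p = min(max(sigmoid(z), eps), 1.0 - eps)
        total += -(y * math.log(p) + (1 - y) * math.log(1.0 - p))
    return total / len(logits)


# Toy example with 4 labels instead of 919; two labels are positive at once
print(multilabel_bce([2.0, -1.5, 0.3, -3.0], [1, 0, 1, 0]))
```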
Estimated Runtime: The total runtime depends on the hardware and on the number of training examples (MAX_EXAMPLES). On a single NVIDIA RTX 4090 GPU, training with MAX_EXAMPLES=100000 and EPOCHS=10 takes approximately 1–2 hours, and a quick test run with MAX_EXAMPLES=1000 takes about 5–10 minutes. Note that the configuration below sets larger values (MAX_EXAMPLES=1000000, EPOCHS=50), so a full run can take considerably longer unless early stopping ends it sooner.
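Runtime scales roughly with the number of optimizer steps, epochs × ceil(examples / batch size). The small sketch below does that arithmetic. It assumes the final partial batch is kept rather than dropped, and it ignores validation passes and early stopping.

```python
import math


def training_steps(num_examples: int, batch_size: int, epochs: int) -> tuple[int, int]:
    """Return (steps_per_epoch, total_steps), counting a final partial batch as one step."""
    steps_per_epoch = math.ceil(num_examples / batch_size)
    return steps_per_epoch, steps_per_epoch * epochs


# The setting quoted above, with BATCH_SIZE = 64 from the configuration section
per_epoch, total = training_steps(num_examples=100_000, batch_size=64, epochs=10)
print(per_epoch, total)  # 1563 15630
```

Because MAX_EXAMPLES is only an upper bound, the real step count is smaller whenever the dataset contains fewer examples.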
1. Setup & Installation
First, let's ensure all the required packages are installed. If you have already installed them, you can skip this cell. Otherwise, uncomment and run the cell to install the dependencies.
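The installation cell itself is not reproduced in this notebook. As a sketch, the check below reports which modules imported later in the notebook are missing. Listing omnigenbench assumes that this is what the local utils.py depends on, and the printed pip command assumes each PyPI distribution name matches its module name.

```python
import importlib.util

# Modules this notebook needs; "omnigenbench" is an assumption about utils.py's dependency
REQUIRED_MODULES = ["autocuda", "findfile", "omnigenbench"]


def missing_modules(names: list[str]) -> list[str]:
    """Return the top-level module names that cannot be found on the current path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_modules(REQUIRED_MODULES)
if missing:
    # Assumes the PyPI package names match the module names
    print("Install with: pip install " + " ".join(missing))
else:
    print("All required packages are available.")
```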
2. Import Libraries
Import all the necessary libraries for genomic data processing, model inference, and analysis.
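import importlib.util
import sys

import autocuda
import findfile

# Load the local utils.py as a module named "utils" so it can be imported below.
# `import importlib` alone does not guarantee that `importlib.util` is available.
utils_spec = importlib.util.spec_from_file_location("utils", "utils.py")
utils = importlib.util.module_from_spec(utils_spec)
sys.modules["utils"] = utils  # register before executing, following the importlib docs recipe
utils_spec.loader.exec_module(utils)

# Import reusable interfaces from local utils
from utils import (
    download_deepsea_dataset,
    load_tokenizer_and_model,
    build_datasets,
    create_dataloaders,
    run_finetuning,
    run_inference,
)
print("Libraries imported successfully.")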
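The cell above loads utils.py directly from its file path with importlib. Below is a slightly more defensive sketch of the same pattern; the helper name load_module_from_path is hypothetical. It fails with a clear error if the file is missing, and it removes a half-initialized module from sys.modules if executing the file raises.

```python
import importlib.util
import sys
from pathlib import Path
from types import ModuleType


def load_module_from_path(name: str, path: str) -> ModuleType:
    """Load a Python source file as module `name` and register it in sys.modules."""
    file_path = Path(path)
    if not file_path.is_file():
        raise FileNotFoundError(f"Module file not found: {file_path.resolve()}")
    spec = importlib.util.spec_from_file_location(name, str(file_path))
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot create an import spec for {file_path}")
    module = importlib.util.module_from_spec(spec)
    sys.modules[name] = module  # register first so `from name import ...` works afterwards
    try:
        spec.loader.exec_module(module)
    except Exception:
        del sys.modules[name]  # do not leave a broken module registered
        raise
    return module


# Usage in this notebook:
# utils = load_module_from_path("utils", "utils.py")
```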
3. Configuration
Here, we define all the hyperparameters and settings for our experiment. This centralized configuration makes it easy to modify parameters and track experiments.
# --- Data File Paths ---
LOCAL_PATH = "deepsea_tfb_prediction"
download_deepsea_dataset(LOCAL_PATH)
TRAIN_FILE = findfile.find_cwd_file(['train', 'jsonl'])
TEST_FILE = findfile.find_cwd_file(['test', 'jsonl'])
VALID_FILE = findfile.find_cwd_file(['valid', 'jsonl'])
# --- Model Configuration ---
# Available models for testing; the first entry is used by default
AVAILABLE_MODELS = [
    'yangheng/OmniGenome-52M',
    'yangheng/OmniGenome-186M',
    'yangheng/OmniGenome-v1.5',
    # You can add more models here as needed:
    # 'DNABERT-2-117M',
    # 'hyenadna-large-1m-seqlen-hf',
    # 'InstaDeepAI/nucleotide-transformer-500m-human-ref',
    # 'multimolecule/rnafm',  # RNA-specific models
    # 'multimolecule/rnabert',
    # 'SpliceBERT-510nt',  # Splice-specific model
]
MODEL_NAME_OR_PATH = AVAILABLE_MODELS[0]
USE_CONV_LAYERS = False # Set to True to add DeepSEA-style convolutional layers on top of OmniGenome (not used in this demo)
# --- Training Hyperparameters ---
EPOCHS = 50
LEARNING_RATE = 5e-5
WEIGHT_DECAY = 1e-3
BATCH_SIZE = 64
PATIENCE = 3 # For early stopping
MAX_LENGTH = 200  # Maximum input length (in tokens) per sequence; the 1000-bp DeepSEA sequences are truncated to fit
SEED = 45
# LABEL_INDICES = list(range(10))  # Example: train on only the first 10 labels
LABEL_INDICES = list(range(919))
MAX_EXAMPLES = 1000000 # Use a smaller number for quick testing (e.g., 1000), or None for all data
DEVICE = autocuda.auto_cuda()
print(f"Using device: {DEVICE}")
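The configuration cell locates each data split with findfile.find_cwd_file, which searches the working directory for a file matching the given keywords. If more than one file matches (for example, another file whose name also contains "train"), the wrong one may be picked, so it can be worth listing every candidate first. Below is a stdlib-only sketch with a hypothetical helper name. It matches keywords against paths relative to the search root and is not findfile's actual implementation.

```python
from pathlib import Path


def find_files_by_keywords(root: str, keywords: list[str]) -> list[Path]:
    """Hypothetical helper: files under `root` whose relative path contains every keyword.

    Results are sorted so the output is deterministic when several files match.
    """
    base = Path(root)
    matches = []
    for path in base.rglob("*"):
        relative = str(path.relative_to(base))
        if path.is_file() and all(keyword in relative for keyword in keywords):
            matches.append(path)
    return sorted(matches)


# Usage (assumes the download placed the splits under LOCAL_PATH):
# for split in ("train", "valid", "test"):
#     print(split, find_files_by_keywords(LOCAL_PATH, [split, "jsonl"]))
```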
4. Model and Dataset Initialization
Initialize the tokenizer and model, then build the datasets and dataloaders with the helper functions from utils.py.
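MAX_LENGTH is 200, while each DeepSEA sequence is 1000 bp, so every input must be shortened. In the original DeepSEA design the labels describe the central 200-bp bin of each 1000-bp window, so it matters which part of the sequence survives truncation. The sketch below shows nucleotide-level encoding with a center crop, padding, and an attention mask. The vocabulary and special-token IDs are made up, and this section does not state whether the OmniGenome tokenizer or build_datasets crops from the start or from the center.

```python
# Hypothetical vocabulary; the real tokenizer defines its own tokens and IDs
VOCAB = {"[PAD]": 0, "[CLS]": 1, "[SEP]": 2, "[UNK]": 3, "A": 4, "C": 5, "G": 6, "T": 7}


def center_crop(sequence: str, length: int) -> str:
    """Keep the central `length` characters (the whole sequence if it is shorter)."""
    if len(sequence) <= length:
        return sequence
    start = (len(sequence) - length) // 2
    return sequence[start:start + length]


def encode(sequence: str, max_length: int) -> dict[str, list[int]]:
    """One token per nucleotide, wrapped in [CLS]/[SEP], then padded to max_length."""
    if max_length < 2:
        raise ValueError("max_length must leave room for [CLS] and [SEP]")
    bases = center_crop(sequence.upper(), max_length - 2)
    ids = [VOCAB["[CLS]"]] + [VOCAB.get(b, VOCAB["[UNK]"]) for b in bases] + [VOCAB["[SEP]"]]
    attention_mask = [1] * len(ids)
    padding = max_length - len(ids)
    ids += [VOCAB["[PAD]"]] * padding
    attention_mask += [0] * padding
    return {"input_ids": ids, "attention_mask": attention_mask}


enc = encode("ACGT" * 250, max_length=200)  # 1000-bp toy input
print(len(enc["input_ids"]), sum(enc["attention_mask"]))  # 200 200
```

With a 1000-bp input and max_length=200, the crop keeps the 198 central bases plus the two special tokens.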
# 1. Initialize Tokenizer and Model
print("--- Initializing Tokenizer and Model ---")
# Use utility to load tokenizer and model
label_count = len(LABEL_INDICES)
tokenizer, model = load_tokenizer_and_model(
    MODEL_NAME_OR_PATH,
    num_labels=label_count,
    threshold=0.5,
    device=DEVICE,
)

# 2. Create Datasets via utility
print("\n--- Creating Datasets ---")
train_set, valid_set, test_set = build_datasets(
    tokenizer=tokenizer,
    train_file=TRAIN_FILE,
    test_file=TEST_FILE,
    valid_file=VALID_FILE,
    max_length=MAX_LENGTH,
    max_examples=MAX_EXAMPLES,
    label_indices=LABEL_INDICES,
)

# Create DataLoaders for batching (utils)
train_loader, valid_loader, test_loader = create_dataloaders(
    train_set=train_set,
    valid_set=valid_set,
    test_set=test_set,
    batch_size=BATCH_SIZE,
)

print("\n--- Initialization Complete ---")
print(f"Training set size: {len(train_set)}")
print(f"Test set size: {len(test_set)}")
if valid_set:
    print(f"Validation set size: {len(valid_set)}")
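build_datasets receives label_indices=LABEL_INDICES, and the model's output layer is sized with label_count = len(LABEL_INDICES), so the two must stay in sync. The sketch below shows the column selection this implies. It assumes each example carries the full 919-entry binary label vector, which is an assumption about the preprocessed files, and select_labels is a hypothetical helper.

```python
def select_labels(full_labels: list[int], label_indices: list[int]) -> list[int]:
    """Keep only the label columns listed in label_indices, in that order."""
    n = len(full_labels)
    bad = [i for i in label_indices if not 0 <= i < n]
    if bad:
        raise IndexError(f"label indices out of range for {n} labels: {bad}")
    return [full_labels[i] for i in label_indices]


# Toy 919-entry label vector with two positives (values are made up)
full = [0] * 919
full[3] = full[500] = 1
print(select_labels(full, list(range(10))))  # [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]
print(len(select_labels(full, list(range(919)))))  # 919
```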
5. Finetuning
Fine-tune the model using AccelerateTrainer (invoked through the run_finetuning compatibility wrapper). Early stopping monitors validation ROC AUC when a validation set is provided.
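For multi-label outputs, ROC AUC is usually computed per label and then averaged. This section does not state how run_finetuning aggregates it (macro average, handling of labels with no positives in the validation split), so the macro average below is an assumption. The pairwise sketch is O(P·N) and only illustrative; scikit-learn's roc_auc_score is the usual production choice.

```python
def roc_auc(scores: list[float], labels: list[int]) -> float:
    """ROC AUC as the probability that a random positive scores above a random
    negative, counting ties as 1/2 (O(P*N), fine for illustration only)."""
    positives = [s for s, y in zip(scores, labels) if y == 1]
    negatives = [s for s, y in zip(scores, labels) if y == 0]
    if not positives or not negatives:
        raise ValueError("ROC AUC is undefined when only one class is present")
    wins = 0.0
    for p in positives:
        for n in negatives:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(positives) * len(negatives))


def macro_roc_auc(score_columns: list[list[float]], label_columns: list[list[int]]) -> float:
    """Average the per-label AUCs, skipping labels that have only one class in this split."""
    aucs = [
        roc_auc(scores, labels)
        for scores, labels in zip(score_columns, label_columns)
        if 0 < sum(labels) < len(labels)
    ]
    if not aucs:
        raise ValueError("no label has both positive and negative examples")
    return sum(aucs) / len(aucs)


# Two toy labels over four validation sequences (one list per label)
scores = [[0.9, 0.8, 0.3, 0.1], [0.2, 0.7, 0.6, 0.4]]
labels = [[1, 1, 0, 0], [1, 0, 1, 0]]
print(macro_roc_auc(scores, labels))  # 0.625
```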
# Train with utilities
print("--- Starting Training ---")
trainer, metrics_best = run_finetuning(
    model=model,
    train_loader=train_loader,
    valid_loader=valid_loader,
    test_loader=test_loader,
    epochs=EPOCHS,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    patience=PATIENCE,
    device=DEVICE,
    save_dir="tfb_model",
)
print(metrics_best)
print("--- Training Finished ---")
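With PATIENCE = 3 and EPOCHS = 50, training can end well before the epoch limit. The sketch below shows a common patience rule for a higher-is-better metric such as validation ROC AUC. Whether AccelerateTrainer compares metrics exactly this way (for example, with or without a minimum improvement threshold) is an assumption.

```python
class EarlyStopping:
    """Stop when the monitored metric (higher is better) has not improved
    for `patience` consecutive evaluations."""

    def __init__(self, patience: int, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("-inf")
        self.bad_epochs = 0

    def step(self, metric: float) -> bool:
        """Record one validation result; return True if training should stop."""
        if metric > self.best + self.min_delta:
            self.best = metric
            self.bad_epochs = 0  # improvement: a checkpoint would typically be saved here
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Made-up validation AUCs, with patience 3 as in the configuration
stopper = EarlyStopping(patience=3)
history = [0.80, 0.83, 0.84, 0.839, 0.838, 0.835, 0.86]
for epoch, auc in enumerate(history, start=1):
    if stopper.step(auc):
        print(f"stop after epoch {epoch}, best AUC {stopper.best}")  # stop after epoch 6, best AUC 0.84
        break
```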
6. Inference Example
Run a single-sequence prediction with the fine-tuned model saved in tfb_model. Inference reuses the same preprocessing path as training (encode_tokens), so the input is encoded exactly as the training data was.
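The inference output below contains both per-label probabilities and 0/1 predictions. Given the threshold=0.5 passed when loading the model, a plausible reading is that predictions come from thresholding per-label sigmoid probabilities. The sketch below shows that mapping from logits; it assumes the model emits one logit per label and that a probability exactly at the threshold counts as positive, and neither point is confirmed by this notebook.

```python
import math


def stable_sigmoid(z: float) -> float:
    """Logistic function written to avoid overflow for large |z|."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def logits_to_outputs(logits: list[float], threshold: float = 0.5) -> dict[str, list]:
    """Map per-label logits to probabilities and 0/1 predictions (p >= threshold -> 1)."""
    probabilities = [stable_sigmoid(z) for z in logits]
    predictions = [int(p >= threshold) for p in probabilities]
    return {"probabilities": probabilities, "predictions": predictions}


# Toy logits for three labels
print(logits_to_outputs([2.0, -1.0, 0.0])["predictions"])  # [1, 0, 1]
```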
sample_sequence = "AGCT" * (MAX_LENGTH // 4)  # Synthetic toy sequence of MAX_LENGTH bp (not a real genomic region)
outputs = run_inference(
    model_dir="tfb_model",
    tokenizer=tokenizer,
    sample_sequence=sample_sequence,
    max_length=MAX_LENGTH,
    device=DEVICE,
)
predictions = outputs.get("predictions")
probabilities = outputs.get("probabilities")
print(f"Input sequence length: {len(sample_sequence)} bp")
if predictions is not None:
    print(f"Number of predicted labels: {len(predictions)}")
    print("\n--- Predictions for the first 10 labels ---")
    for i in range(min(10, len(predictions))):
        # Map the output position back to its column in the original DeepSEA label vector
        label_id = LABEL_INDICES[i]
        # Labels cover TF binding, DNase I sensitivity, and histone marks, so report positive/negative
        pred_label = "Positive" if int(predictions[i]) == 1 else "Negative"
        if probabilities is not None:
            try:
                p = float(probabilities[i])
                print(f"Label {label_id}: Prediction={pred_label}, Prob={p:.4f}")
            except (TypeError, ValueError, IndexError):
                print(f"Label {label_id}: Prediction={pred_label}")
        else:
            print(f"Label {label_id}: Prediction={pred_label}")
else:
    print("No 'predictions' returned by run_inference; verify the saved model and inference API.")
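Printing the first 10 outputs shows only an arbitrary slice of the label set. The sketch below ranks every output by probability and reports its original label column via LABEL_INDICES. top_k_labels is a hypothetical helper, and it assumes probabilities is a flat sequence with one value per selected label, as the loop above does.

```python
def top_k_labels(
    probabilities: list[float], label_indices: list[int], k: int = 5
) -> list[tuple[int, float]]:
    """Return (label column, probability) pairs for the k most probable labels, highest first."""
    if len(probabilities) != len(label_indices):
        raise ValueError("need exactly one probability per selected label")
    ranked = sorted(zip(label_indices, probabilities), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


print(top_k_labels([0.1, 0.9, 0.5], [10, 20, 30], k=2))  # [(20, 0.9), (30, 0.5)]

# Usage in this notebook (when probabilities is not None):
# for label_id, p in top_k_labels([float(x) for x in probabilities], LABEL_INDICES, k=5):
#     print(f"Label {label_id}: Prob={p:.4f}")
```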