
Translation Efficiency Prediction with OmniGenBench

This notebook demonstrates how to fine-tune a Genomic Foundation Model (GFM) to predict rice translation efficiency (TE) from mRNA sequences using OmniGenBench. The example focuses on plant RNA translation efficiency.

Background (PlantRNA-FM): PlantRNA-FM ("An interpretable RNA foundation model for exploring functional RNA motifs in plants") introduces an RNA foundation model tailored to plant genomics, highlighting interpretability and motif discovery capabilities. TE prediction is a representative downstream task where models learn sequence determinants associated with efficient translation. In this demo, we use a small rice 5'UTR/mRNA TE dataset to illustrate fine-tuning and evaluation within OmniGenBench.

Task type: Binary sequence classification (High-TE vs. Low-TE)
Input: RNA sequences (strings), up to a configurable max_length
Label space: {0: Low-TE, 1: High-TE}
Estimated runtime: On a single NVIDIA RTX 4090, a short training run on this toy dataset typically takes ~10–30 minutes, depending on the number of epochs and the model size.
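As a rough illustration of the input constraints above, the hypothetical helper below (not part of OmniGenBench) normalizes a raw sequence and caps it at max_length characters. The function name `prepare_sequence` and the allowed alphabet are assumptions; OmniGenBench's tokenizer may truncate at the token level instead, so treat this as an approximation only.

```python
def prepare_sequence(seq: str, max_length: int = 1024) -> str:
    """Hypothetical helper: clean an RNA/DNA string and cap its length."""
    # Remove all whitespace (leading, trailing, internal) and unify case
    cleaned = "".join(seq.split()).upper()
    # Reject characters outside an assumed nucleotide alphabet
    allowed = set("ACGTUN")
    bad = set(cleaned) - allowed
    if bad:
        raise ValueError(f"Unexpected characters in sequence: {sorted(bad)}")
    # Keep only the first max_length characters (character-level approximation)
    return cleaned[:max_length]

# Human-readable names for the label space described above
ID2NAME = {0: "Low-TE", 1: "High-TE"}

print(prepare_sequence(" acgu ACGT ", max_length=6))  # ACGUAC
```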

1. Setup & Installation

This section handles the initial setup, including installing necessary packages. If dependencies are already available, you can skip the installation cell.

pip install torch transformers pandas autocuda multimolecule biopython scipy scikit-learn tqdm dill findfile requests omnigenbench matplotlib seaborn
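To confirm the dependencies are available before running the remaining cells, a quick check like the one below can help. It only looks up each module with `importlib.util.find_spec` and does not import it. Some import names differ from the pip names (scikit-learn is imported as `sklearn`, biopython as `Bio`); the mapping below is our assumption for this notebook's package list.

```python
import importlib.util

# pip package name -> import name (assumed mapping for this notebook)
REQUIRED = {
    "torch": "torch",
    "transformers": "transformers",
    "pandas": "pandas",
    "autocuda": "autocuda",
    "multimolecule": "multimolecule",
    "biopython": "Bio",
    "scipy": "scipy",
    "scikit-learn": "sklearn",
    "tqdm": "tqdm",
    "dill": "dill",
    "findfile": "findfile",
    "requests": "requests",
    "omnigenbench": "omnigenbench",
    "matplotlib": "matplotlib",
    "seaborn": "seaborn",
}

def missing_packages(required: dict) -> list:
    """Return the pip names whose top-level import name cannot be found."""
    return [pip_name for pip_name, module_name in required.items()
            if importlib.util.find_spec(module_name) is None]

missing = missing_packages(REQUIRED)
print("Missing packages:", missing if missing else "none")
```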

2. Import Libraries

Import all necessary libraries for data processing, model training, and analysis.

import warnings
import importlib.util  # importlib.util must be imported explicitly
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import torch
from autocuda import auto_cuda

# Import utilities from the local utils.py file
utils_spec = importlib.util.spec_from_file_location("utils", "utils.py")
utils = importlib.util.module_from_spec(utils_spec)
utils_spec.loader.exec_module(utils)

warnings.filterwarnings('ignore')
print("Libraries imported successfully!")
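The cell above executes utils.py through importlib but does not add it to `sys.modules`, so the later `from utils import ...` statements import the file a second time through the normal search path. A small helper that registers the module before executing it avoids the double import. The name `load_module_from_path` is hypothetical.

```python
import importlib.util
import sys
from types import ModuleType

def load_module_from_path(name: str, path: str) -> ModuleType:
    """Load a Python file as a module and register it under `name`."""
    spec = importlib.util.spec_from_file_location(name, path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load module {name!r} from {path!r}")
    module = importlib.util.module_from_spec(spec)
    # Register before executing so a later `import name` reuses this object
    sys.modules[name] = module
    spec.loader.exec_module(module)
    return module

# Usage (assuming utils.py sits next to the notebook):
# utils = load_module_from_path("utils", "utils.py")
```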

3. Configuration & Data

Set up the analysis parameters, file paths, and model selection here. You can easily change the MODEL_NAME to test different GFMs.

Model Selection

Choose the Genomic Foundation Model to fine-tune.

# Using utils for reusable logic
from utils import run_finetuning
print("Core classes and functions imported from utils.")

# --- Available Models for Testing ---
AVAILABLE_MODELS = [
    # 'yangheng/OmniGenome-52M',
    # 'yangheng/OmniGenome-186M',
    'yangheng/OmniGenome-v1.5',
]
MODEL_NAME = AVAILABLE_MODELS[0]  # Model to fine-tune and evaluate
print(f"Selected model: {MODEL_NAME}")

Hyperparameter and Dataset Configuration

Define the training hyperparameters and paths to the dataset files.

import findfile

# --- Training Hyperparameters ---
EPOCHS = 10
LEARNING_RATE = 2e-5
WEIGHT_DECAY = 1e-5
BATCH_SIZE = 4
MAX_LENGTH = 1024
SEED = 42

# --- Dataset Configuration ---
LOCAL_PATH = "te_rice_dataset"  # Local directory for the dataset
from utils import download_te_dataset
download_te_dataset(LOCAL_PATH)  # Download the TE dataset if not already available

# --- Dataset Split Files ---
TRAIN_FILE = findfile.find_cwd_file("train.json")  # training split
VALID_FILE = findfile.find_cwd_file("valid.json")  # validation split (optional)
TEST_FILE = findfile.find_cwd_file("test.json")  # test split
# --- Label Mapping ---
# The task is binary classification: 1 for High-TE, 0 for Low-TE.
LABEL2ID = {"0": 0, "1": 1}

print(f"Train: {TRAIN_FILE}\nValid: {VALID_FILE}\nTest: {TEST_FILE}")
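Before training, it can be worth checking the class balance in each split. The sketch below assumes each split file is either a JSON array or JSON Lines of records with a `label` field; the actual field names in te_rice_dataset may differ, so adjust `label_key` as needed. Both helper names are hypothetical.

```python
import json
from collections import Counter

def load_records(path: str) -> list:
    """Read a JSON array file or a JSON Lines file into a list of dicts."""
    with open(path, "r", encoding="utf-8") as f:
        text = f.read().strip()
    if not text:
        return []
    if text.startswith("["):
        return json.loads(text)  # the whole file is one JSON array
    # Otherwise treat each non-empty line as one JSON record (JSON Lines)
    return [json.loads(line) for line in text.splitlines() if line.strip()]

def label_counts(path: str, label_key: str = "label") -> Counter:
    """Count label values (as strings, matching the LABEL2ID keys) in one split."""
    return Counter(str(rec[label_key]) for rec in load_records(path))

# Usage with the paths defined above (field name "label" is an assumption):
# for name, path in [("train", TRAIN_FILE), ("valid", VALID_FILE), ("test", TEST_FILE)]:
#     if path:
#         print(name, dict(label_counts(path)))
```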

4. Main Analysis Pipeline

This section executes the training/evaluation pipeline using the configuration defined above. The run_finetuning function from examples/translation_efficiency_prediction/utils.py orchestrates tokenization, dataset creation, training, and evaluation.

# Import main pipeline from utils for a concise demo
from utils import run_finetuning

print("Main analysis pipeline imported from utils.")

print("=" * 50)

# Run the analysis
metrics = run_finetuning(
    model_name=MODEL_NAME,
    train_file=TRAIN_FILE,
    valid_file=VALID_FILE,
    test_file=TEST_FILE,
    label2id=LABEL2ID,
    epochs=EPOCHS,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    batch_size=BATCH_SIZE,
    max_length=MAX_LENGTH,
    seed=SEED,
)

print("=" * 50)
print("Analysis completed!")

# Print final validation and test metrics (if available)
if metrics.get('valid'):
    print("\nValidation Set Performance (last epoch):")
    for key, value in metrics['valid'][-1].items():
        print(f"{key}: {value:.4f}")

if metrics.get('test'):
    print("\nTest Set Performance (best checkpoint):")
    for key, value in metrics['test'][-1].items():
        print(f"{key}: {value:.4f}")
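The metric names printed above depend on what utils.py configures. For reference, macro F1 and the Matthews correlation coefficient (MCC), the two metrics the plotting cell looks for, can be computed from true and predicted labels with scikit-learn as shown below. The labels here are made-up placeholders, not results from this notebook.

```python
from sklearn.metrics import f1_score, matthews_corrcoef

# Placeholder labels (0 = Low-TE, 1 = High-TE); not real model outputs
y_true = [1, 0, 1, 1, 0, 0]
y_pred = [1, 0, 0, 1, 0, 1]

# Macro F1 averages the per-class F1 scores, weighting both classes equally
macro_f1 = f1_score(y_true, y_pred, average="macro")
# MCC ranges from -1 to 1 and stays informative under class imbalance
mcc = matthews_corrcoef(y_true, y_pred)
print(f"macro F1: {macro_f1:.4f}, MCC: {mcc:.4f}")  # macro F1: 0.6667, MCC: 0.3333
```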

5. Inference Example

This section demonstrates how to run inference on individual sequences with the fine-tuned model. model.inference accepts a raw sequence string and tokenizes it internally, so the input does not need to be encoded by hand.

from omnigenbench import ModelHub

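device = auto_cuda()  # e.g. "cuda:0", or "cpu" if no GPU is available
model = ModelHub.load("finetuned_te_model").to(device)
if "cuda" in str(device):
    model = model.to(torch.float16)  # half precision only on GPU; many CPU ops lack float16 support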

# Example sequence for inference
sample_sequence1 = "AAACCAACAAAATGCAGTAGAAGTACTCTCGAGCTATAGTCGCGACGTGCTGCCCCGCAGGAGTACAGTAGTAGTACAACGTAAGCGGGAGCAACAGACTCCCCCCCTGCAACCCACTGTGCCTGTGCCCTCGACGCGTCTCCGTCGCTTTGGCAAATGTCACGTACATATTACCGTCTCAGGCTCTCAGCCATGCTCCCTACCACCCCTGCAGCGAAGCAAAAGCCACGCACGCGGCGCCTGACATGTAACAGGACTAGACCATCTTGTTCATTTCCCGCACCCCCTCCTCTCCTCTTCCTCCATCTGCCTCTTTAAAACAGTAAAAATAACCGTGCATCCCCTGGGCAAAATCTCTCCCATACATACACTACAGCGGCGAACCTTTCCTTATTCTCGCAACGCCTCGGTAACGGGCAGCGCCTGCTCCGCGCCGCGGTTGCGAGTTCGGGAAGGCGGCCGGAGTCGCGGGGAGGAGAGGGAGGATTCGATCGGCCAGA"  # High-TE sequence

sample_sequence2 = "TGGAGATGGGCAGATGGCACACAAAACATGAATAGAAAACCCAAAAGGAAGGATGAAAAAAACACACACACACACACACACAAAACACAGAGAGAGAGAGAGAGAGAGAGCGAGAAAAGAAAAGAAAAAACCAATTCTTTTGGTCTCTTCCCTCTCCGTTTGTCGTGTCGAAGCCTTTGCCCCCACCACCTCCTCCTCTCCTCTCCCTTCCTCCCCTCCTCCCCATCTCGCTCTCCTCCCTCCTCTCTCCTCTCCTCGTCTCCTCTTCCTCTCCATTCCATTGGCCATTCCATTCCATTCCACCCCCCATGAAACCCCAAACCCTCGTCGGCCTCGCCGCGCTCGCGTAGCGCACCCGCCCTTCTCCTCTCGCCGGTGGTCCGCCGCCAGCCTCCCCCCACCCGATCCCGCCGCCCCCCCCGCCTTCACCCCGCCCACGCGGACGCATCCGATCCCGCCGCATCGCCGCGCGGGGGGGGGGGGGGGGGGGGGAGGGCACG "  # Low-TE sequence

# Run inference on the sample sequences
outputs = model.inference(sample_sequence1)
print(f"Sample sequence 1 prediction: {outputs}")
outputs = model.inference(sample_sequence2.strip())  # strip the stray trailing space in the literal
print(f"Sample sequence 2 prediction: {outputs}")
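The exact structure of `outputs` depends on the OmniGenBench version. If you have raw two-class logits (for example, from a forward pass), the sketch below shows one common way to turn them into probabilities and a Low-TE/High-TE label. The function name and the assumption that index 1 means High-TE (following LABEL2ID above) are ours.

```python
import torch

ID2LABEL = {0: "Low-TE", 1: "High-TE"}  # mirrors the label space defined earlier

def logits_to_prediction(logits: torch.Tensor) -> tuple:
    """Convert a 1-D tensor of two class logits into (label name, P(High-TE))."""
    # Cast to float32 first so float16 logits from a half-precision model work too
    probs = torch.softmax(logits.float(), dim=-1)
    pred_id = int(torch.argmax(probs).item())
    return ID2LABEL[pred_id], float(probs[1].item())

label, p_high = logits_to_prediction(torch.tensor([0.2, 1.3]))
print(label, round(p_high, 3))
```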

6. Visualization

In this section, we visualize validation metrics across epochs to assess learning dynamics.

Plot Validation Curves

We plot macro F1 across epochs when it is available, falling back to MCC or, failing that, the first numeric metric. Which metrics are reported depends on the configuration in utils.py.

# Plot the validation curve
valid_key = 'valid' if 'valid' in metrics else ('eval' if 'eval' in metrics else None)
if valid_key is None:
    print("No validation metrics found for plotting.")
else:
    valid_df = pd.DataFrame(metrics[valid_key])
    plt.style.use('seaborn-v0_8-whitegrid')
    fig, ax = plt.subplots(1, 1, figsize=(7, 5))
    if 'f1_score' in valid_df.columns:
        sns.lineplot(data=valid_df, x=valid_df.index, y='f1_score', ax=ax, label='Validation F1 (Macro)')
        ax.set_ylabel('F1 Score')
    elif 'matthews_corrcoef' in valid_df.columns:
        sns.lineplot(data=valid_df, x=valid_df.index, y='matthews_corrcoef', ax=ax, label='Validation MCC')
        ax.set_ylabel('MCC')
    else:
        numeric_cols = valid_df.select_dtypes(include='number').columns.tolist()
        if numeric_cols:
            sns.lineplot(data=valid_df, x=valid_df.index, y=numeric_cols[0], ax=ax, label=numeric_cols[0])
            ax.set_ylabel(numeric_cols[0])
        else:
            print("Validation metrics exist but no numeric columns to plot.")
    ax.set_title('Validation Metric across Epochs')
    ax.set_xlabel('Epoch')
    ax.legend()
    plt.tight_layout()
    plt.show()
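To identify the epoch with the highest validation score, assuming (as the cells above do) that `metrics['valid']` is a list of per-epoch dicts, a small helper like the one below can be used. The function name and the default key `f1_score` are assumptions; epochs are 0-indexed to match the plot's x-axis.

```python
def best_epoch(history: list, key: str = "f1_score") -> tuple:
    """Return (epoch index, metric dict) for the epoch with the largest `key`."""
    scored = [(i, rec) for i, rec in enumerate(history) if key in rec]
    if not scored:
        raise KeyError(f"No epoch in the history reports {key!r}")
    # max() keeps the earliest epoch when several tie for the best value
    return max(scored, key=lambda pair: pair[1][key])

# Usage with the metrics returned by run_finetuning:
# idx, rec = best_epoch(metrics["valid"])
# print(f"Best epoch: {idx}, F1 = {rec['f1_score']:.4f}")
```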

References

PlantRNA-FM: "An interpretable RNA foundation model for exploring functional RNA motifs in plants" (Nature Machine Intelligence, 2024).