Skip to content

Large-scale Benchmarking with OmniGenBench Using LoRA Fine-tuning

1. Setup & Installation

# NOTE(review): section 4 imports `multimolecule` when a 'multimolecule/...'
# checkpoint is selected; it is not listed here, so install it separately
# unless one of the packages below already pulls it in.
pip install omnigenbench transformers peft accelerate bitsandbytes

2. Import Libraries

import random
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from omnigenbench import AutoBench

# Reached only when every import above succeeded.
print("Libraries imported successfully.")

3. Configuration

# --- General Settings ---
# Benchmark suite handed to AutoBench, for example "GUE" or "RGB".
BENCHMARK = "GUE"
# The next three values are forwarded to bench.run() as batch_size,
# patience and epochs respectively.
BATCH_SIZE = 32
PATIENCE = 3
EPOCHS = 20
# Forwarded as max_examples: keep it small for quick testing, or set it to
# None to use all data.
MAX_EXAMPLES = 1000
# Drawn anew on every execution, so two runs of this file generally use
# different seeds; pin a constant here when results must be repeatable.
SEED = random.randint(0, 1000)

# --- Model Selection ---
# The single model fine-tuned in section 5.
GFM_TO_TUNE = "yangheng/OmniGenome-52M"

# Every model visited, in this order, by the multi-model loop in section 6.
AVAILABLE_GFMS = [
    "yangheng/OmniGenome-52M",
    "yangheng/OmniGenome-186M",
    "yangheng/OmniGenome-v1.5",
    "zhihan1996/DNABERT-2-117M",
    "LongSafari/hyenadna-large-1m-seqlen-hf",
    "InstaDeepAI/nucleotide-transformer-v2-100m-multi-species",
    "kuleshov-group/caduceus-ph_seqlen-131k_d_model-256_n_layer-16",
    "multimolecule/rnafm",
    # Optional entries, disabled by default:
    #   "arcinstitute/evo-1-131k-base"  (Evo models need the evo package installed)
    #   "SpliceBERT-510nt"
]

# --- LoRA Configuration ---
def _lora_config(*target_modules):
    """Return a fresh LoRA settings dict that adapts ``target_modules``.

    Every entry in ``LORA_CONFIGS`` uses the same rank, alpha, dropout and
    bias values; only the names of the adapted sub-modules differ. A new
    dict and a new list are built on each call so entries never share
    mutable state.
    """
    return {
        "r": 8,
        "lora_alpha": 32,
        "lora_dropout": 0.1,
        "target_modules": list(target_modules),
        "bias": "none",
    }


# Target-module name sets that several models have in common.
_KEY_VALUE_DENSE = ("key", "value", "dense")
_EVO_L123_PROJECTIONS = (
    "Wqkv", "out_proj",
    "l1", "l2", "l3",
    "projections",
    "out_filter_dense",
)

# Keyed by the model name with its hub namespace removed, i.e. the part of
# the identifier after the last "/".
LORA_CONFIGS = {
    # Default GFM_TO_TUNE. Without this entry the lookup for the default
    # model returns None and the run receives no LoRA settings at all.
    # NOTE(review): assumes the 52M checkpoint exposes the same module names
    # as OmniGenome-186M -- confirm against the checkpoint.
    "OmniGenome-52M": _lora_config(*_KEY_VALUE_DENSE),
    "OmniGenome-186M": _lora_config(*_KEY_VALUE_DENSE),
    "caduceus-ph_seqlen-131k_d_model-256_n_layer-16": _lora_config(
        "in_proj", "x_proj", "out_proj"
    ),
    "rnamsm": _lora_config("q_proj", "v_proj", "out_proj"),
    "rnafm": _lora_config(*_KEY_VALUE_DENSE),
    "rnabert": _lora_config(*_KEY_VALUE_DENSE),
    "agro-nucleotide-transformer-1b": _lora_config(*_KEY_VALUE_DENSE),
    "SpliceBERT-510nt": _lora_config(*_KEY_VALUE_DENSE),
    "DNABERT-2-117M": _lora_config("Wqkv", "dense"),
    "3utrbert": _lora_config(*_KEY_VALUE_DENSE),
    "hyenadna-large-1m-seqlen-hf": _lora_config("in_proj", "out_proj"),
    "nucleotide-transformer-v2-100m-multi-species": _lora_config(
        *_KEY_VALUE_DENSE
    ),
    "evo-1-131k-base": _lora_config(
        "Wqkv", "out_proj",
        "mlp",
        "projections",
        "out_filter_dense",
    ),
    "evo-1.5-8k-base": _lora_config(*_EVO_L123_PROJECTIONS),
    "evo-1-8k-base": _lora_config(*_EVO_L123_PROJECTIONS),
    # Same list as the 8k Evo models, except that "projections" is left out.
    "evo2_7b": _lora_config(
        "Wqkv", "out_proj",
        "l1", "l2", "l3",
        "out_filter_dense",
    ),
}

# Echo the active configuration. The LoRA entry is looked up by the model
# name with everything up to the last "/" removed; None is shown when the
# table has no entry for that name.
print(
    "Configuration loaded:",
    f"  - GFM to Tune: {GFM_TO_TUNE}",
    f"  - Benchmark: {BENCHMARK}",
    f"  - Epochs: {EPOCHS}",
    f"  - LoRA Config: {LORA_CONFIGS.get(GFM_TO_TUNE.rsplit('/', 1)[-1])}",
    sep="\n",
)

4. Model-Specific Loading

def load_gfm_and_tokenizer(gfm_name):
    """Resolve ``gfm_name`` into the ``(model, tokenizer)`` pair used by AutoBench.

    Three cases are distinguished by substring checks on ``gfm_name``:

    * contains ``'multimolecule'``: the tokenizer and model are loaded through
      the ``multimolecule`` package and the model's ``base_model`` is returned;
    * otherwise contains ``'evo-1'``: the checkpoint is loaded through
      ``transformers`` with ``trust_remote_code=True``, its ``backbone`` is
      patched and returned together with the tokenizer;
    * anything else: nothing is loaded; the name itself is returned as the
      "model" and the tokenizer is ``None``.

    Args:
        gfm_name: Model identifier string.

    Returns:
        A ``(model, tokenizer)`` tuple as described above.
    """
    print(f"\nLoading model and tokenizer for: {gfm_name}")

    via_multimolecule = 'multimolecule' in gfm_name

    # Fast path: no special handling, hand the plain name back.
    if not via_multimolecule and 'evo-1' not in gfm_name:
        print("Using standard model name for AutoBench.")
        return gfm_name, None

    if via_multimolecule:
        # Imported lazily so the package is only required for these models.
        from multimolecule import RnaTokenizer, AutoModelForTokenPrediction

        rna_tokenizer = RnaTokenizer.from_pretrained(gfm_name)
        token_model = AutoModelForTokenPrediction.from_pretrained(
            gfm_name, trust_remote_code=True
        )
        base = token_model.base_model
        print("Loaded multimolecule model with custom classes.")
        return base, rna_tokenizer

    # Remaining case: the name contains 'evo-1'.
    evo_config = AutoConfig.from_pretrained(gfm_name, trust_remote_code=True)
    causal_lm = AutoModelForCausalLM.from_pretrained(
        gfm_name, config=evo_config, trust_remote_code=True
    )
    backbone = causal_lm.backbone
    evo_tokenizer = AutoTokenizer.from_pretrained(gfm_name, trust_remote_code=True)
    # NOTE(review): the pad token id is taken from pad_token_type_id,
    # presumably because this tokenizer defines no pad token -- confirm.
    evo_tokenizer.pad_token_id = evo_tokenizer.pad_token_type_id

    # Attach the config to the bare backbone and mirror the pad id onto it.
    backbone.config = evo_config
    backbone.config.pad_token_id = evo_tokenizer.pad_token_id
    # Replace the un-embedding step with an identity function.
    backbone.unembed.unembed = lambda x: x
    print("Loaded Evo model with custom patching.")
    return backbone, evo_tokenizer

# Only confirms that the definition above ran; the function is not called here.
print("Model loading function defined.")

5. Running LoRA Fine-tuning

# Resolve the selected model: either a loaded (model, tokenizer) pair or the
# plain model name with tokenizer=None.
model, tokenizer = load_gfm_and_tokenizer(GFM_TO_TUNE)

print(f"\nInitializing AutoBench for benchmark: {BENCHMARK}")
bench = AutoBench(
    benchmark=BENCHMARK,
    model_name_or_path=model,
    tokenizer=tokenizer,
    overwrite=True,
    trainer='native',  # 'native' or 'accelerate'
    autocast='fp16',  # 'fp16', 'bf16', or 'fp32'
    device='cuda',
)

# LoRA settings are keyed by the model name without its hub namespace.
lora_key = GFM_TO_TUNE.split('/')[-1]
lora_config = LORA_CONFIGS.get(lora_key, None)
if lora_config is None:
    # Make the fallback visible: without this message a missing table entry
    # goes unnoticed and the run is still reported as "LoRA fine-tuning".
    print(
        f"WARNING: no LoRA config found for '{lora_key}'; "
        "bench.run() will receive lora_config=None."
    )

print(f"\nStarting LoRA fine-tuning for {GFM_TO_TUNE}...")
bench.run(
    batch_size=BATCH_SIZE,
    gradient_accumulation_steps=1,
    patience=PATIENCE,
    max_examples=MAX_EXAMPLES,
    seeds=SEED,
    epochs=EPOCHS,
    lora_config=lora_config,
)

print("\n🎉 LoRA fine-tuning complete!")
print("Check the 'autobench_logs' and 'autobench_evaluations' directories for results.")

6. Multi-Model LoRA Fine-tuning (Optional)

import traceback  # used below to show where a failing model raised

print("Starting multi-model LoRA fine-tuning...")
print("="*50)

# (model name, error text) for every model that raised. Only the text is
# kept so the exception object and its traceback are not held alive.
failed_gfms = []

for gfm in AVAILABLE_GFMS:
    try:
        model, tokenizer = load_gfm_and_tokenizer(gfm)
        bench = AutoBench(
            benchmark=BENCHMARK,
            model_name_or_path=model,
            tokenizer=tokenizer,
            overwrite=True,
            trainer='native',
            autocast='fp16',
            device='cuda',
        )
        # LoRA settings are keyed by the model name without its namespace.
        lora_key = gfm.split('/')[-1]
        lora_config = LORA_CONFIGS.get(lora_key, None)
        if lora_config is None:
            # Surface the fallback instead of silently passing None on.
            print(
                f"WARNING: no LoRA config found for '{lora_key}'; "
                "bench.run() will receive lora_config=None."
            )

        print(f"\nStarting LoRA fine-tuning for {gfm}...")
        bench.run(
            batch_size=BATCH_SIZE,
            # Stated explicitly so this loop uses the same arguments as the
            # single-model run in section 5.
            gradient_accumulation_steps=1,
            patience=PATIENCE,
            max_examples=MAX_EXAMPLES,
            seeds=SEED,
            epochs=EPOCHS,
            lora_config=lora_config,
        )
        print(f"Finished fine-tuning for {gfm}.")
        print("="*50)

    except Exception as e:
        # Deliberately broad: one failing model must not abort the sweep.
        # The traceback is printed so the failure can still be diagnosed.
        print(f"An error occurred while processing {gfm}: {e}")
        traceback.print_exc()
        failed_gfms.append((gfm, repr(e)))
        print("="*50)

if failed_gfms:
    print(f"\n{len(failed_gfms)} of {len(AVAILABLE_GFMS)} models failed:")
    for failed_name, failure_text in failed_gfms:
        print(f"  - {failed_name}: {failure_text}")

print("\n🎉 All models have been processed!")