Large-scale Benchmarking with OmniGenBench Using LoRA Fine-tuning
1. Setup & Installation
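The notebook needs OmniGenBench plus its fine-tuning dependencies. A minimal install sketch, assuming the PyPI package name omnigenbench and that the LoRA adapters come from peft; the optional packages are only needed for the corresponding model families:
!pip install omnigenbench torch transformers peft
!pip install multimolecule   # optional: only for multimolecule/* models
# Evo checkpoints additionally require the `evo` package (see the model list below).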
2. Import Libraries
import random
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from omnigenbench import AutoBench
print("Libraries imported successfully.")
3. Configuration
# --- General Settings ---
BENCHMARK = "GUE" # Benchmark suite to use, e.g., "GUE", "RGB"
BATCH_SIZE = 32
PATIENCE = 3
EPOCHS = 20
MAX_EXAMPLES = 1000 # Use a smaller number for quick testing, set to None for all data
SEED = random.randint(0, 1000)  # New random seed each run; set a fixed integer for reproducibility
# --- Model Selection ---
GFM_TO_TUNE = 'yangheng/OmniGenome-52M'
AVAILABLE_GFMS = [
    'yangheng/OmniGenome-52M',
    'yangheng/OmniGenome-186M',
    'yangheng/OmniGenome-v1.5',
    'zhihan1996/DNABERT-2-117M',
    'LongSafari/hyenadna-large-1m-seqlen-hf',
    'InstaDeepAI/nucleotide-transformer-v2-100m-multi-species',
    'kuleshov-group/caduceus-ph_seqlen-131k_d_model-256_n_layer-16',
    'multimolecule/rnafm',
    # Evo models (require installing the `evo` package):
    # 'arcinstitute/evo-1-131k-base',
    # 'SpliceBERT-510nt',
]
# --- LoRA Configuration ---
LORA_CONFIGS = {
    # Keys are the basename of the Hugging Face model id (matched via GFM_TO_TUNE.split('/')[-1]).
    "OmniGenome-186M": {
        "r": 8, "lora_alpha": 32, "lora_dropout": 0.1,
        "target_modules": ["key", "value", "dense"], "bias": "none"
    },
    "caduceus-ph_seqlen-131k_d_model-256_n_layer-16": {
        "r": 8, "lora_alpha": 32, "lora_dropout": 0.1,
        "target_modules": ["in_proj", "x_proj", "out_proj"], "bias": "none"
    },
    "rnamsm": {
        "r": 8, "lora_alpha": 32, "lora_dropout": 0.1,
        "target_modules": ["q_proj", "v_proj", "out_proj"], "bias": "none"
    },
    "rnafm": {
        "r": 8, "lora_alpha": 32, "lora_dropout": 0.1,
        "target_modules": ["key", "value", "dense"], "bias": "none"
    },
    "rnabert": {
        "r": 8, "lora_alpha": 32, "lora_dropout": 0.1,
        "target_modules": ["key", "value", "dense"], "bias": "none"
    },
    "agro-nucleotide-transformer-1b": {
        "r": 8, "lora_alpha": 32, "lora_dropout": 0.1,
        "target_modules": ["key", "value", "dense"], "bias": "none"
    },
    "SpliceBERT-510nt": {
        "r": 8, "lora_alpha": 32, "lora_dropout": 0.1,
        "target_modules": ["key", "value", "dense"], "bias": "none"
    },
    "DNABERT-2-117M": {
        "r": 8, "lora_alpha": 32, "lora_dropout": 0.1,
        "target_modules": ["Wqkv", "dense"], "bias": "none"
    },
    "3utrbert": {
        "r": 8, "lora_alpha": 32, "lora_dropout": 0.1,
        "target_modules": ["key", "value", "dense"], "bias": "none"
    },
    "hyenadna-large-1m-seqlen-hf": {
        "r": 8, "lora_alpha": 32, "lora_dropout": 0.1,
        "target_modules": ["in_proj", "out_proj"], "bias": "none"
    },
    "nucleotide-transformer-v2-100m-multi-species": {
        "r": 8, "lora_alpha": 32, "lora_dropout": 0.1,
        "target_modules": ["key", "value", "dense"], "bias": "none"
    },
    "evo-1-131k-base": {
        "r": 8, "lora_alpha": 32, "lora_dropout": 0.1,
        "target_modules": [
            "Wqkv", "out_proj",
            "mlp",
            "projections",
            "out_filter_dense"
        ],
        "bias": "none"
    },
    "evo-1.5-8k-base": {
        "r": 8, "lora_alpha": 32, "lora_dropout": 0.1,
        "target_modules": [
            "Wqkv", "out_proj",
            "l1", "l2", "l3",
            "projections",
            "out_filter_dense"
        ],
        "bias": "none"
    },
    "evo-1-8k-base": {
        "r": 8, "lora_alpha": 32, "lora_dropout": 0.1,
        "target_modules": [
            "Wqkv", "out_proj",
            "l1", "l2", "l3",
            "projections",
            "out_filter_dense"
        ],
        "bias": "none"
    },
    "evo2_7b": {
        "r": 8, "lora_alpha": 32, "lora_dropout": 0.1,
        "target_modules": [
            "Wqkv", "out_proj",
            "l1", "l2", "l3",
            # "projections",
            "out_filter_dense"
        ],
        "bias": "none"
    },
}
print(f"Configuration loaded:")
print(f" - GFM to Tune: {GFM_TO_TUNE}")
print(f" - Benchmark: {BENCHMARK}")
print(f" - Epochs: {EPOCHS}")
print(f" - LoRA Config: {LORA_CONFIGS.get(GFM_TO_TUNE.split('/')[-1], None)}")
4. Model-Specific Loading
def load_gfm_and_tokenizer(gfm_name):
    """Loads a GFM and its tokenizer, handling special cases."""
    print(f"\nLoading model and tokenizer for: {gfm_name}")
    if 'multimolecule' in gfm_name:
        from multimolecule import RnaTokenizer, AutoModelForTokenPrediction
        tokenizer = RnaTokenizer.from_pretrained(gfm_name)
        model = AutoModelForTokenPrediction.from_pretrained(
            gfm_name, trust_remote_code=True
        ).base_model
        print("Loaded multimolecule model with custom classes.")
    elif 'evo-1' in gfm_name:
        config = AutoConfig.from_pretrained(gfm_name, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            gfm_name, config=config, trust_remote_code=True
        ).backbone
        tokenizer = AutoTokenizer.from_pretrained(gfm_name, trust_remote_code=True)
        # Evo tokenizers ship without a pad token, so reuse the pad token type id.
        tokenizer.pad_token_id = tokenizer.pad_token_type_id
        model.config = config
        model.config.pad_token_id = tokenizer.pad_token_id
        # Bypass the unembedding head so the backbone returns hidden states.
        model.unembed.unembed = lambda x: x
        print("Loaded Evo model with custom patching.")
    else:
        # Standard Hugging Face models: pass the name through and let
        # AutoBench handle loading and tokenization internally.
        tokenizer = None
        model = gfm_name
        print("Using standard model name for AutoBench.")
    return model, tokenizer
print("Model loading function defined.")
5. Running LoRA Fine-tuning
model, tokenizer = load_gfm_and_tokenizer(GFM_TO_TUNE)
print(f"\nInitializing AutoBench for benchmark: {BENCHMARK}")
bench = AutoBench(
    benchmark=BENCHMARK,
    model_name_or_path=model,
    tokenizer=tokenizer,
    overwrite=True,
    trainer='native',  # 'native' or 'accelerate'
    autocast='fp16',   # 'fp16', 'bf16', or 'fp32'
    device='cuda',
)
lora_config = LORA_CONFIGS.get(GFM_TO_TUNE.split('/')[-1], None)
print(f"\nStarting LoRA fine-tuning for {GFM_TO_TUNE}...")
bench.run(
    batch_size=BATCH_SIZE,
    gradient_accumulation_steps=1,
    patience=PATIENCE,
    max_examples=MAX_EXAMPLES,
    seeds=SEED,
    epochs=EPOCHS,
    lora_config=lora_config,
)
print("\n🎉 LoRA fine-tuning complete!")
print("Check the 'autobench_logs' and 'autobench_evaluations' directories for results.")
6. Multi-Model LoRA Fine-tuning (Optional)
print("Starting multi-model LoRA fine-tuning...")
print("="*50)
for gfm in AVAILABLE_GFMS:
    try:
        model, tokenizer = load_gfm_and_tokenizer(gfm)
        bench = AutoBench(
            benchmark=BENCHMARK,
            model_name_or_path=model,
            tokenizer=tokenizer,
            overwrite=True,
            trainer='native',
            autocast='fp16',
            device='cuda',
        )
        lora_config = LORA_CONFIGS.get(gfm.split('/')[-1], None)
        print(f"\nStarting LoRA fine-tuning for {gfm}...")
        bench.run(
            batch_size=BATCH_SIZE,
            patience=PATIENCE,
            max_examples=MAX_EXAMPLES,
            seeds=SEED,
            epochs=EPOCHS,
            lora_config=lora_config,
        )
        print(f"Finished fine-tuning for {gfm}.")
        print("="*50)
    except Exception as e:
        # Log the failure and move on so one broken model doesn't stop the sweep.
        print(f"An error occurred while processing {gfm}: {e}")
        print("="*50)
        continue
print("\n🎉 All models have been processed!")