X-Pert for Drug Perturbation Prediction
This notebook adapts the gene-perturbation workflow to drug-perturbation analysis. It:
- adapts the perturbation format to handle drug names and dosages
- loads drug embeddings (GROVER, RDKit, ECFP, etc.) instead of gene embeddings
- incorporates dosage information into the model
- keeps the same clean structure and organization as the gene workflow
1. Imports
[1]:
# Core
import os, json, time, copy, warnings
from pathlib import Path
from typing import Dict
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch import nn
from tqdm import tqdm
import matplotlib.pyplot as plt
# X-Pert package
import xpert
from xpert.models import TransformerGenerator
from xpert.loss import masked_relative_error, masked_mse_loss, masked_huber_loss
from xpert.external_model.scgpt.gene_tokenizer import GeneVocab
from xpert.external_model.gears.inference import compute_metrics
# from xpert.data import Byte_Pert_Data
# Utils
from xpert.utils import fix_seed, merge_plot
from xpert.external_model.scgpt.util import set_seed, map_raw_id_to_vocab_id, add_file_handler
2. Global Settings and Logger
[2]:
# Reproducibility: seed through both the scGPT helper and the X-Pert helper (in this order)
set_seed(42)
fix_seed(2024)

# Output directory for this run; attach a file handler writing <save_root>/run.log
save_root = Path("./L1000_phase1/model_mode_scFlamingo_drug_v1")
save_root.mkdir(exist_ok=True, parents=True)
logger = xpert.logger
run_log_path = save_root / "run.log"
add_file_handler(logger, run_log_path)
run_started = time.strftime('%Y-%m-%d %H:%M:%S')
logger.info(f"Running on {run_started}")

# Reset matplotlib to its defaults and suppress all Python warnings
plt.rcdefaults()
warnings.filterwarnings("ignore")
INFO:xpert:Running on 2025-11-01 17:32:36
3. Data Configuration
[3]:
# Dataset and split
prefix = 'L1000_phase1'  # or 'sciplex_3', etc.
add_control = False
# Column holding the train/test/val assignment for drug perturbation datasets:
#   data_split_0 - random split of perturbations
#   data_split_1 - split by drug
#   data_split_2 - split by cell type
split_col = 'data_split_1'

# Paths
data_dir = Path('../../data', prefix)  # drug perturbation data
pert_data_version = 'pert_data.pkl'
drug_embed_dir = data_dir  # drug embedding CSVs live next to the data

# Pretrained scGPT model: config, weights and gene vocabulary
load_model_dir = Path('../../data', 'scGPT_human')
model_config_file, model_file, vocab_file = (
    load_model_dir / name for name in ('args.json', 'best_model.pt', 'vocab.json')
)
4. Model Hyperparameters
Note: For drug perturbation we typically use fewer transformer layers (nlayers = 4-12, depending on computational resources); the configuration below sets nlayers = 2 for a quick run. Adjust device_ids to match the GPUs available on your machine.
[4]:
# Training hyperparameters
batch_size = 64
eval_batch_size = 64
log_interval = 100  # log running training loss every `log_interval` batches
# Cell encoder hyperparameters
embsize = 512  # may be overridden by the pretrained model config
d_hid = 512  # may be overridden by the pretrained model config
nlayers = 2  # transformer layers; kept small for faster training, raise if resources allow
nhead = 8
n_layers_cls = 3  # may be overridden by the pretrained model config
dropout = 0.2
use_fast_transformer = True
amp = True  # mixed-precision (autocast + GradScaler)
device_ids = [0, 1, 2, 3]  # GPU ids to use; adjust based on available GPUs
# Drug embedding settings
drug_embed_mode = 'ecfp'  # Options: 'grover', 'rdkit', 'morgan', 'ecfp', 'chemberta_st', 'grover+rdkit'
# Embedding dimension depends on the embedding method (set after loading embeddings)
gpt_emb_dim = None
# add-token settings
load_encoder_plus = True  # copy encoder weights to encoder_plus when adding tokens
# Gene selection: "all" feeds every gene; "batch-wise" feeds only genes that are
# non-zero somewhere in the batch
include_zero_gene = "all"
# Task settings
# X-Pert settings for DRUG perturbation
pert_mode = 'drug'  # Changed from 'gene'
pert_flag_mode = True  # Use drug-gene interaction matrix when True
delta_mode = False  # when True the model predicts changes relative to the input expression
attn_gate_mode = True
load_cxg_weight = True  # load pretrained encoder / transformer weights
mask_mode = True
use_scgpt_layer = True
use_scgpt_input = True
add_token = True  # extend the vocab with dataset genes missing from it
init_mode = False
cross_mode = True
# Dosage settings
dosage_mode = True  # Enable dosage information
dosage_mode_type = 1  # 0: direct multiply; 1: learnable vector; 2: additive with learnable
# Scheduler
epochs = 2
lr = 1e-4
# 'cosine_warm' / 'cosine' are stepped every batch; 'steplr' is stepped once per epoch
scheduler_type = 'cosine_warm'
schedule_interval = 5  # StepLR step size in epochs (steplr only)
# Gradient clipping
max_norm = 1.0
5. Helper Losses and Training/Eval Routines
[5]:
# LR scheduler: linear warmup followed by cosine decay
from torch.optim.lr_scheduler import LambdaLR
import math


def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    """Build a LambdaLR with linear warmup and half-cosine decay.

    The LR multiplier ramps linearly from 0 to 1 over ``num_warmup_steps``
    steps. It then follows half a cosine from 1 towards 0 over the remaining
    ``num_training_steps - num_warmup_steps`` steps. The multiplier is clamped
    at 0 from below. The training loop steps this scheduler once per batch.
    """
    warmup_denominator = float(max(1, num_warmup_steps))
    decay_span = float(max(1, num_training_steps - num_warmup_steps))

    def lr_lambda(step):
        if step >= num_warmup_steps:
            progress = (step - num_warmup_steps) / decay_span
            return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))
        return float(step) / warmup_denominator

    return LambdaLR(optimizer, lr_lambda)
# Training / Evaluation routines for DRUG perturbation
def _parse_pert_strings(pert_strings):
    """Split drug perturbation strings of the form "drug1; drug2 | dosage1; dosage2".

    Args:
        pert_strings: iterable of perturbation strings, one per sample.

    Returns:
        (batch_perts, batch_dosages): ``batch_perts[i]`` is the list of drug
        names of sample ``i``. ``batch_dosages[i]`` holds the matching
        ``log10(dosage + 1)`` values.

    Raises:
        ValueError: if a string has no " | " dosage section, or lists fewer
            dosages than drugs.
    """
    batch_perts, batch_dosages = [], []
    for pert_str in pert_strings:
        sections = pert_str.split(' | ')
        if len(sections) < 2:
            raise ValueError(f"Perturbation {pert_str!r} has no ' | ' dosage section")
        drugs = sections[0].split("; ")
        doses = [np.log10(float(d) + 1) for d in sections[1].split("; ")]
        if len(doses) < len(drugs):
            raise ValueError(f"Perturbation {pert_str!r} lists fewer dosages than drugs")
        batch_perts.append(drugs)
        batch_dosages.append(doses)
    return batch_perts, batch_dosages


def _build_pert_inputs(pert_strings, device):
    """Embed and pad the drug perturbations of one batch.

    Reads the notebook globals ``pert_embed_dict``, ``gpt_emb_dim`` and ``dosage_mode``.

    Returns:
        A tuple ``(batch_pert_embed, pert_mask, batch_dosages_pad)``, all on
        ``device``. Here ``P`` is the largest drug count of any sample.

        - ``batch_pert_embed``: float tensor of shape ``(n_samples, P, gpt_emb_dim)``;
          padded slots are zero.
        - ``pert_mask``: float tensor of shape ``(n_samples, P)``; 0 marks a real
          drug, 1 marks padding.
        - ``batch_dosages_pad``: float tensor of shape ``(n_samples, P)`` holding
          ``log10(dosage + 1)``, or ``None`` when ``dosage_mode`` is False.
    """
    batch_perts, batch_dosages = _parse_pert_strings(pert_strings)
    n_samples = len(batch_perts)
    max_pert_len = max(len(drugs) for drugs in batch_perts)
    batch_pert_embed = torch.zeros(n_samples, max_pert_len, gpt_emb_dim)
    pert_mask = torch.ones(n_samples, max_pert_len)
    batch_dosages_pad = torch.zeros(n_samples, max_pert_len)
    for i, (drugs, doses) in enumerate(zip(batch_perts, batch_dosages)):
        for j, drug in enumerate(drugs):
            batch_pert_embed[i, j] = torch.as_tensor(
                np.asarray(pert_embed_dict[drug]), dtype=torch.float32
            )
            pert_mask[i, j] = 0
            batch_dosages_pad[i, j] = doses[j]
    # Move everything the model receives to the same device
    # (pert_mask was previously left on the CPU).
    batch_pert_embed = batch_pert_embed.to(device)
    pert_mask = pert_mask.to(device)
    batch_dosages_pad = batch_dosages_pad.to(device) if dosage_mode else None
    return batch_pert_embed, pert_mask, batch_dosages_pad


def _get_pert_flags(batch_data):
    """Return the per-gene perturbation flags of a batch.

    The flags are kept as-is when drug mode uses the drug-gene interaction
    matrix (``pert_flag_mode``); otherwise they are cast to long.
    """
    if pert_mode == 'drug' and pert_flag_mode:
        return batch_data.pert_flags
    return batch_data.pert_flags.long()


def _select_input_genes(ori_gene_values, n_total_genes, device, mode, max_len=None):
    """Choose which gene (column) indices are fed to the model.

    Args:
        ori_gene_values: expression tensor indexed as ``[:, gene]``.
        n_total_genes: number of genes used when ``mode == "all"``.
        device: device for the generated index tensor.
        mode: ``"all"`` (every gene) or ``"batch-wise"`` (genes that are
            non-zero in at least one sample of the batch).
        max_len: if given and more genes are selected, keep a random subset
            of ``max_len`` of the selected genes.

    Raises:
        ValueError: for an unsupported ``mode``.
    """
    if mode == "all":
        input_gene_ids = torch.arange(n_total_genes, device=device, dtype=torch.long)
    elif mode == "batch-wise":
        input_gene_ids = ori_gene_values.nonzero()[:, 1].flatten().unique().sort()[0]
    else:
        raise ValueError(f"Unsupported include_zero_gene mode: {mode!r}")
    if max_len is not None and len(input_gene_ids) > max_len:
        # Sample from the selected ids themselves. A bare randperm would return
        # positions 0..len-1, i.e. the wrong genes in batch-wise mode.
        keep = torch.randperm(len(input_gene_ids), device=input_gene_ids.device)[:max_len]
        input_gene_ids = input_gene_ids[keep]
    return input_gene_ids


def _run_model(model, input_gene_ids, input_values, input_pert_flags, pert_inputs,
               vocab_gene_ids, batch_size):
    """Map raw gene indices to vocab ids, run ``model`` and return its ``mlm_output``.

    Callers wrap this in ``torch.cuda.amp.autocast`` when AMP is enabled.
    """
    batch_pert_embed, pert_mask, batch_dosages_pad = pert_inputs
    mapped_input_gene_ids = map_raw_id_to_vocab_id(input_gene_ids, vocab_gene_ids)
    mapped_input_gene_ids = mapped_input_gene_ids.repeat(batch_size, 1)
    # No padding: every selected gene is a real token.
    src_key_padding_mask = torch.zeros_like(input_values, dtype=torch.bool)
    output_dict = model(
        mapped_input_gene_ids,
        input_values,
        input_pert_flags,
        src_key_padding_mask=src_key_padding_mask,
        batch_pert_embed=batch_pert_embed,
        pert_mask=pert_mask,
        batch_dosages_pad=batch_dosages_pad,  # None when dosage_mode is False
    )
    return output_dict["mlm_output"]


def train(model: nn.Module, train_loader: torch.utils.data.DataLoader) -> None:
    """Run one training epoch over ``train_loader``.

    Relies on these notebook globals:

    - ``device``, ``optimizer``, ``scheduler``, ``scheduler_type``, ``scaler``
    - ``criterion``, ``amp``, ``max_norm``, ``epoch``, ``log_interval``
    - ``include_zero_gene``, ``n_genes``, ``max_seq_len``, ``gene_ids``, ``delta_mode``
    - the drug-embedding globals used by ``_build_pert_inputs``

    The 'cosine' and 'cosine_warm' schedulers are stepped after every batch.
    """
    model.train()
    total_loss, total_mse, n_window = 0.0, 0.0, 0
    start_time = time.time()
    num_batches = len(train_loader)
    for batch, batch_data in enumerate(tqdm(train_loader)):
        batch_size = len(batch_data.y)
        batch_data.to(device)
        ori_gene_values = batch_data.x
        pert_flags = _get_pert_flags(batch_data)
        target_gene_values = batch_data.y
        pert_inputs = _build_pert_inputs(batch_data.pert, device)

        input_gene_ids = _select_input_genes(
            ori_gene_values, n_genes, device, include_zero_gene, max_len=max_seq_len
        )
        input_values = ori_gene_values[:, input_gene_ids]
        input_pert_flags = pert_flags[:, input_gene_ids]
        target_values = target_gene_values[:, input_gene_ids]
        if delta_mode:
            target_values = target_values - input_values

        with torch.cuda.amp.autocast(enabled=amp):
            output_values = _run_model(
                model, input_gene_ids, input_values, input_pert_flags,
                pert_inputs, gene_ids, batch_size,
            )
            masked_positions = torch.ones_like(input_values, dtype=torch.bool)
            loss = loss_mse = criterion(output_values, target_values, masked_positions)

        model.zero_grad()
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)  # clip on true (unscaled) gradients
        # Under AMP, non-finite gradients are expected occasionally and
        # GradScaler skips those steps, so only raise when AMP is off.
        torch.nn.utils.clip_grad_norm_(
            model.parameters(), max_norm, error_if_nonfinite=not scaler.is_enabled()
        )
        scaler.step(optimizer)
        scaler.update()
        if scheduler_type in ('cosine', 'cosine_warm'):
            scheduler.step()

        total_loss += loss.item()
        total_mse += loss_mse.item()
        n_window += 1
        if batch % log_interval == 0 and batch > 0:
            lr_ = scheduler.get_last_lr()[0]
            # Average over the batches actually accumulated since the last log
            # (the first window spans log_interval + 1 batches).
            ms_per_batch = (time.time() - start_time) * 1000 / n_window
            cur_loss = total_loss / n_window
            cur_mse = total_mse / n_window
            logger.info(f"| epoch {epoch:3d} | {batch:3d}/{num_batches:3d} batches | lr {lr_:07.6f} | ms/batch {ms_per_batch:5.2f} | loss {cur_loss:5.5f} | mse {cur_mse:5.5f} |")
            total_loss, total_mse, n_window = 0.0, 0.0, 0
            start_time = time.time()


@torch.no_grad()
def evaluate(model: nn.Module, val_loader: torch.utils.data.DataLoader):
    """Compute the mean per-batch loss and masked relative error on ``val_loader``.

    Uses the same notebook globals as ``train`` for gene selection and inputs.

    Returns:
        (mean_loss, mean_relative_error), both averaged over batches.

    Raises:
        ValueError: if ``val_loader`` has no batches.
    """
    if len(val_loader) == 0:
        raise ValueError("evaluate() received an empty val_loader")
    model.eval()
    total_loss = 0.0
    total_error = 0.0
    for batch_data in val_loader:
        batch_size = len(batch_data.y)
        batch_data.to(device)
        ori_gene_values = batch_data.x
        pert_flags = _get_pert_flags(batch_data)
        target_gene_values = batch_data.y
        pert_inputs = _build_pert_inputs(batch_data.pert, device)

        input_gene_ids = _select_input_genes(
            ori_gene_values, n_genes, device, include_zero_gene, max_len=max_seq_len
        )
        input_values = ori_gene_values[:, input_gene_ids]
        input_pert_flags = pert_flags[:, input_gene_ids]
        target_values = target_gene_values[:, input_gene_ids]
        if delta_mode:
            target_values = target_values - input_values

        with torch.cuda.amp.autocast(enabled=amp):
            output_values = _run_model(
                model, input_gene_ids, input_values, input_pert_flags,
                pert_inputs, gene_ids, batch_size,
            )
            masked_positions = torch.ones_like(input_values, dtype=torch.bool)
            loss = criterion(output_values, target_values, masked_positions)
        total_loss += loss.item()
        total_error += masked_relative_error(output_values, target_values, masked_positions).item()
    return total_loss / len(val_loader), total_error / len(val_loader)


@torch.no_grad()
def pred_perturb_new(model, batch_data, include_zero_gene="batch-wise", gene_ids=None, amp=True):
    """Predict expression for every gene of ``batch_data``.

    Args:
        model: trained model; it is switched to eval mode, and its parameter
            device is used.
        batch_data: batch with ``x``, ``pert``, ``pert_flags`` attributes.
        include_zero_gene: ``"all"`` or ``"batch-wise"`` gene selection.
        gene_ids: raw-gene-index to vocab-id mapping (required).
        amp: run the forward pass under autocast.

    Returns:
        A float tensor shaped like ``batch_data.x``. Genes not fed to the model
        are 0, or keep their input value in ``delta_mode``.

    Raises:
        ValueError: if ``gene_ids`` is None or ``include_zero_gene`` is unsupported.
    """
    if gene_ids is None:
        raise ValueError("pred_perturb_new() requires gene_ids")
    model.eval()
    device = next(model.parameters()).device
    batch_data.to(device)
    batch_size = len(batch_data.pert)
    ori_gene_values = batch_data.x
    pert_flags = _get_pert_flags(batch_data)
    pert_inputs = _build_pert_inputs(batch_data.pert, device)

    input_gene_ids = _select_input_genes(
        ori_gene_values, ori_gene_values.size(1), device, include_zero_gene
    )
    input_values = ori_gene_values[:, input_gene_ids]
    input_pert_flags = pert_flags[:, input_gene_ids]

    with torch.cuda.amp.autocast(enabled=amp):
        output_values = _run_model(
            model, input_gene_ids, input_values, input_pert_flags,
            pert_inputs, gene_ids, batch_size,
        ).float()
    pred_gene_values = torch.zeros_like(ori_gene_values)
    pred_gene_values[:, input_gene_ids] = output_values
    if delta_mode:
        # Predictions are deltas on the selected genes. Add back the full input
        # profile: input_values only covers the selected columns and would not
        # broadcast against the full-width prediction in batch-wise mode.
        pred_gene_values = ori_gene_values + pred_gene_values
    return pred_gene_values
[6]:
def eval_perturb_new(
    loader,
    model,
    device: torch.device,
):
    """
    Run model in inference mode using a given data loader.

    Predictions come from ``pred_perturb_new``, using the notebook globals
    ``include_zero_gene`` and ``gene_ids``.

    Returns:
        ``None`` if ``loader`` yields no batches. Otherwise a dict with these keys:

        - ``pert_cat``: perturbation string of each sample.
        - ``pred`` / ``truth``: all genes, as float64 arrays.
        - ``pred_de`` / ``truth_de``: the genes listed in each sample's
          ``de_idx``, as float64 arrays.
    """
    model.eval()
    model.to(device)
    pert_cat = []
    pred = []
    truth = []
    pred_de = []
    truth_de = []
    with torch.no_grad():
        for batch in loader:
            batch.to(device)
            pert_cat.extend(batch.pert)
            p = pred_perturb_new(model, batch, include_zero_gene, gene_ids=gene_ids)
            t = batch.y
            pred.extend(p.cpu())
            truth.extend(t.cpu())
            # Differentially expressed genes: slice on the prediction's device,
            # then move to CPU so results do not accumulate in GPU memory.
            for bidx, de_idx in enumerate(batch.de_idx):
                pred_de.append(p[bidx, de_idx].cpu())
                truth_de.append(t[bidx, de_idx].cpu())
    if not pred:
        # Empty loader: torch.stack([]) would raise; callers check for None.
        return None
    results = {}
    # all genes
    results["pert_cat"] = np.array(pert_cat)
    results["pred"] = torch.stack(pred).detach().cpu().numpy().astype(np.float64)
    results["truth"] = torch.stack(truth).detach().cpu().numpy().astype(np.float64)
    # DE genes
    results["pred_de"] = torch.stack(pred_de).detach().cpu().numpy().astype(np.float64)
    results["truth_de"] = torch.stack(truth_de).detach().cpu().numpy().astype(np.float64)
    return results
6. Data Loading and Tokenization
[ ]:
# Build dataset via Byte_Pert_Data
seed = 2024
bs_train = 32  # per-GPU training batch size (multiplied by len(device_ids) for the loaders)
bs_test = bs_train * 2
# Load the preprocessed pert_data object.
# NOTE: unpickling can execute arbitrary code - only load files from a trusted source.
import pickle
with open(data_dir / pert_data_version, 'rb') as fh:  # handle is closed even if loading fails
    pert_data = pickle.load(fh)
# Use the genes of the split AnnData as the variable genes
pert_data.var_genes = pert_data.adata_split.var_names
# Build X-Pert train/test/val datasets. get_Data_scgpt_2 is presumably the
# drug-format variant of get_Data_scgpt (it takes split_col and the drug
# embedding directory) - TODO confirm.
pert_data.get_Data_scgpt_2(
    num_de_genes=pert_data.num_de_genes,
    dataset_name=['train', 'test', 'val'],
    add_control=add_control,
    split_col=split_col,  # Drug datasets use specific split columns
    pert_flag_mode=pert_flag_mode,  # For drug-gene interaction matrix
    drug_embed_dir=drug_embed_dir,
)
100%|██████████| 101775/101775 [09:15<00:00, 183.18it/s]
100%|██████████| 26950/26950 [02:32<00:00, 176.48it/s]
0it [00:00, ?it/s]
========== get Data_scGPT finished!
[8]:
# Use the split AnnData as the working AnnData and expose gene names as a column
pert_data.adata = pert_data.adata_split
pert_data.adata.var["gene_name"] = pert_data.adata.var_names
# DataLoaders: per-GPU batch sizes are multiplied by the number of configured
# GPUs (presumably so each DataParallel replica receives bs_train / bs_test)
n_gpu_replicas = len(device_ids)
trainloader, testloader, valloader = pert_data.get_dataloader(
    mode='all',
    bs_train=n_gpu_replicas * int(bs_train),
    bs_test=n_gpu_replicas * int(bs_test),
)
[9]:
tuple(len(loader) for loader in (trainloader, testloader, valloader))
[9]:
(4029, 530, 0)
[10]:
tuple(len(pert_data.dataloader[key]) for key in ("train_loader", "val_loader", "test_loader"))
[10]:
(4029, 530, 530)
7. Drug Embeddings and Vocabulary
[11]:
# Load drug embeddings based on drug_embed_mode.
# Drugs are looked up as columns below (one column per drug), so the embedding
# dimension is the number of rows.
_embed_files = {
    'grover': ['embed_grover.csv'],
    'rdkit': ['embed_rdkit.csv'],
    'morgan': ['embed_morgan.csv'],
    'ecfp': ['embed_ecfp.csv'],
    'chemberta_st': ['embed_chemberta_st.csv'],
    'grover+rdkit': ['embed_grover.csv', 'embed_rdkit.csv'],
}
if drug_embed_mode not in _embed_files:
    raise ValueError(f"Unknown drug_embed_mode: {drug_embed_mode}")
_embed_parts = [
    pd.read_csv(drug_embed_dir / fname, sep=',', index_col=0)
    for fname in _embed_files[drug_embed_mode]
]
if len(_embed_parts) == 1:
    pert_embed = _embed_parts[0]
else:
    # Combined embeddings: stack the feature rows of each table, then
    # standardize every feature across drugs.
    from sklearn.preprocessing import StandardScaler
    pert_embed = pd.concat(_embed_parts, axis=0)
    # Named embed_scaler (not `scaler`) so re-running this cell cannot replace
    # the AMP GradScaler global that train() uses.
    embed_scaler = StandardScaler()
    pert_embed = pd.DataFrame(
        embed_scaler.fit_transform(pert_embed.T).T,
        index=pert_embed.index,
        columns=pert_embed.columns,
    )
if pert_embed.isna().to_numpy().any():
    # e.g. drugs present in only one of the concatenated tables
    logger.warning(f"pert_embed ({drug_embed_mode}) contains NaN values")
# Set embedding dimension
gpt_emb_dim = pert_embed.shape[0]
# Collect every individual drug name. Perturbation strings look like
# "drug1; drug2 | dosage1; dosage2": keep the part before " | " and split on "; ".
all_drugs = [
    drug
    for pert_str in pert_data.filter_perturbation_list
    for drug in pert_str.split(' | ')[0].split('; ')
]
total_perts = np.unique(all_drugs)
# Map each drug to its embedding vector. A drug missing from the table gets the
# vector of a randomly chosen (seeded) drug, so every perturbation can be embedded.
pert_embed_dict: Dict[str, np.ndarray] = {}
np.random.seed(2024)
for pert in total_perts:
    if pert in pert_embed.columns:
        pert_embed_dict[pert] = pert_embed.loc[:, pert].values
    else:
        logger.warning(f'{pert} not in pert_embed, using random embedding')
        pert_embed_dict[pert] = pert_embed.loc[:, np.random.choice(pert_embed.columns, 1)[0]].values
logger.info(f"Loaded {len(pert_embed_dict)} drug embeddings with dimension {gpt_emb_dim}")
# Load the pretrained gene vocabulary and make sure the special tokens exist
vocab = GeneVocab.from_file(vocab_file)
special_tokens = ["<pad>", "<cls>", "<eoc>"]
for s in special_tokens:
    if s not in vocab:
        vocab.append_token(s)
# Snapshot of the vocabulary before dataset genes are appended
vocab_ori = copy.deepcopy(vocab)
if add_token:
    # add dataset genes that are missing from the pretrained vocab
    add_genes = np.setdiff1d(pert_data.adata.var_names, list(vocab.get_stoi().keys()))
    for gene in add_genes:
        if gene not in vocab:
            vocab.append_token(gene)
# Mark genes present in vocab (1) or absent (-1)
pert_data.adata.var["id_in_vocab"] = [1 if gene in vocab else -1 for gene in pert_data.adata.var_names]
gene_ids_in_vocab = np.array(pert_data.adata.var["id_in_vocab"])
logger.info(f"match {np.sum(gene_ids_in_vocab >= 0)}/{len(gene_ids_in_vocab)} genes in vocabulary of size {len(vocab)}.")
# Build token ids for input genes; genes not in the vocab map to <pad>
genes = pert_data.adata.var["gene_name"].tolist()
vocab.set_default_index(vocab["<pad>"])
gene_ids = np.array([vocab[g] if g in vocab else vocab["<pad>"] for g in genes], dtype=int)
n_genes = len(genes)
# Upper bound on the number of genes fed to the model per sample in training/evaluation
max_seq_len = 6000
INFO:xpert:Loaded 17200 drug embeddings with dimension 2048
INFO:xpert:match 978/978 genes in vocabulary of size 60726.
8. Build Model, Optimizer and Scheduler
[12]:
# Load model config
with open(model_config_file, 'r') as f:
    model_configs = json.load(f)
# Sizes from the pretrained config take precedence over the values set earlier
embsize = model_configs.get('embsize', embsize)
d_hid = model_configs.get('d_hid', d_hid)
n_layers_cls = model_configs.get('n_layers_cls', n_layers_cls)
# Token count of the ORIGINAL vocabulary (before dataset genes were appended).
# The extended vocabulary is passed separately via `vocab=`; presumably it backs
# `encoder_plus` - TODO confirm in TransformerGenerator.
ntokens = len(vocab_ori)
# Model
model = TransformerGenerator(
    ntokens,
    embsize,
    nhead,
    d_hid,
    nlayers,
    nlayers_cls=n_layers_cls,
    n_cls=1,
    vocab=vocab,
    dropout=dropout,
    pad_token='<pad>',
    pad_value=0,
    pert_pad_id=2,
    use_fast_transformer=use_fast_transformer,
    pert_embed_dict=pert_embed_dict,
    gpt_emb_dim=gpt_emb_dim,
    model_mode='scFlamingo_drug_v1',  # or specify your model mode
    attn_gate_mode=attn_gate_mode,
    pert_mode=pert_mode,
    drug_embed_mode=drug_embed_mode,
    pert_flag_mode=pert_flag_mode,
    use_scgpt_layer=use_scgpt_layer,
    use_scgpt_input=use_scgpt_input,
    mask_mode=mask_mode,
    add_token=add_token,
    init_mode=init_mode,
    dosage_mode_type=dosage_mode_type,
    cross_mode=cross_mode,
)
# Optionally load pretrained weights for the encoders and transformer layers
if load_cxg_weight:
    load_param_prefixs = ["encoder", "value_encoder", "transformer_encoder"]
    # The model is still on the CPU here; map_location avoids materializing the
    # checkpoint on whichever GPU it was saved from.
    state = torch.load(model_file, map_location="cpu")
    model_dict = model.state_dict()
    pretrained_dict = {k: v for k, v in state.items() if k.startswith(tuple(load_param_prefixs))}
    model_dict.update(pretrained_dict)
    missing = model.load_state_dict(model_dict, strict=False)
    # model_dict already holds every model key, so "missing" stays 0.
    # "unexpected" counts pretrained tensors with no counterpart in this model
    # (presumably e.g. transformer layers beyond `nlayers`).
    logger.info(f"Missing keys: {len(missing.missing_keys)}")
    logger.info(f"Unexpected keys: {len(missing.unexpected_keys)}")
# If adding tokens, optionally load encoder weights into encoder_plus and init new tokens
if add_token and load_encoder_plus:
    with torch.no_grad():
        n = model.encoder.embedding.num_embeddings
        # copy base encoder weights and norms
        model.encoder_plus.embedding.weight[:n] = model.encoder.embedding.weight
        model.encoder_plus.enc_norm.weight[:] = model.encoder.enc_norm.weight
        model.encoder_plus.enc_norm.bias[:] = model.encoder.enc_norm.bias
        # initialize newly added tokens with mean/std of pretrained embeddings
        pretrained_embed = model.encoder.embedding.weight  # (n, d)
        mean = pretrained_embed.mean(dim=0)
        std = pretrained_embed.std(dim=0)
        m = model.encoder_plus.embedding.num_embeddings - n
        if m > 0:
            model.encoder_plus.embedding.weight[n:] = torch.normal(
                mean=mean.expand(m, -1),
                std=std.expand(m, -1),
            )
# Device & DataParallel
if torch.cuda.is_available():
    # Drop configured GPU ids that do not exist on this machine (DataParallel
    # rejects invalid ids); fall back to GPU 0 if none remain.
    device_ids = [d for d in device_ids if d < torch.cuda.device_count()] or [0]
    device = torch.device(f"cuda:{device_ids[0]}")
else:
    device = torch.device("cpu")
if torch.cuda.is_available() and len(device_ids) > 1:
    model = torch.nn.DataParallel(model, device_ids=device_ids).to(device)
else:
    model = model.to(device)
# Loss, optimizer, scheduler
loss_type = 'mse'
criterion = masked_mse_loss if loss_type == 'mse' else masked_huber_loss
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
if scheduler_type == 'steplr':
    # stepped once per epoch in the training loop
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, schedule_interval, gamma=0.9)
elif scheduler_type == 'cosine':
    # stepped once per batch inside train()
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(trainloader)*epochs)
elif scheduler_type == 'cosine_warm':
    # 5% of all training steps are linear warmup; stepped once per batch inside train()
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(len(trainloader)*epochs*0.05),
        num_training_steps=len(trainloader)*epochs,
    )
else:
    raise ValueError('Unknown scheduler')
scaler = torch.cuda.amp.GradScaler(enabled=amp)
Using simple batchnorm instead of domain specific batchnorm
INFO:xpert:Missing keys: 0
INFO:xpert:Unexpected keys: 120
9. Training Loop and Validation
[13]:
# Training loop: train one epoch, validate, keep the best model by validation
# loss, then collect per-epoch metrics on the train and validation sets.
# NOTE(review): the 'val' dataset appears to be empty ("0it" while building it,
# and valloader has 0 batches), yet pert_data.dataloader["val_loader"] has as
# many batches as the test loader - confirm model selection is not done on the
# test split.
best_val_loss = float('inf')
best_model = None
patience = 0
early_stop = epochs  # patience budget; equal to `epochs`, so early stopping effectively never fires
train_metrics_list, train_metrics_pert_list = [], []
val_metrics_list, val_metrics_pert_list = [], []
for epoch in range(1, epochs + 1):  # `epoch` is also read by train() when logging
    epoch_start_time = time.time()
    train_loader = pert_data.dataloader["train_loader"]
    valid_loader = pert_data.dataloader["val_loader"]

    train(model, train_loader)
    val_loss, val_mre = evaluate(model, valid_loader)

    elapsed = time.time() - epoch_start_time
    rule = "-" * 89
    logger.info(rule)
    logger.info(f"| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | valid loss/mse {val_loss:5.4f} |")
    logger.info(rule)

    if not val_loss < best_val_loss:
        patience += 1
        if patience >= early_stop:
            logger.info(f"Early stop at epoch {epoch}")
            break
    else:
        best_val_loss = val_loss
        best_model = copy.deepcopy(model)
        logger.info(f"Best model with score {best_val_loss:5.4f}")
        patience = 0
        # Keys carry a "module." prefix when the model is wrapped in DataParallel
        torch.save(best_model.state_dict(), save_root / "model_best.pt")

    # Stop on a diverged (NaN) validation loss
    if np.isnan(val_loss):
        logger.warning(f"NaN loss detected at epoch {epoch}, stopping training")
        break

    # Per-epoch metrics via compute_metrics; each call is a full inference pass (expensive)
    train_res = eval_perturb_new(train_loader, model, device)
    val_res = eval_perturb_new(valid_loader, model, device)
    if train_res is not None:
        train_metrics, train_metrics_pert = compute_metrics(train_res)
        train_metrics_list.append(train_metrics)
        train_metrics_pert_list.append(train_metrics_pert)
    if val_res is not None:
        val_metrics, val_metrics_pert = compute_metrics(val_res)
        val_metrics_list.append(val_metrics)
        val_metrics_pert_list.append(val_metrics_pert)

    # StepLR advances once per epoch; the cosine schedulers advance per batch in train()
    if scheduler_type == 'steplr':
        scheduler.step()
2%|▏ | 100/4029 [00:17<08:15, 7.94it/s]INFO:xpert:| epoch 1 | 100/4029 batches | lr 0.000025 | ms/batch 171.49 | loss 65.00415 | mse 65.00415 |
5%|▍ | 200/4029 [00:29<08:00, 7.98it/s]INFO:xpert:| epoch 1 | 200/4029 batches | lr 0.000050 | ms/batch 126.04 | loss 4.40624 | mse 4.40624 |
7%|▋ | 300/4029 [00:42<07:48, 7.96it/s]INFO:xpert:| epoch 1 | 300/4029 batches | lr 0.000075 | ms/batch 126.08 | loss 1.59986 | mse 1.59986 |
10%|▉ | 400/4029 [00:54<07:35, 7.96it/s]INFO:xpert:| epoch 1 | 400/4029 batches | lr 0.000100 | ms/batch 125.96 | loss 1.56468 | mse 1.56468 |
12%|█▏ | 500/4029 [01:07<07:20, 8.01it/s]INFO:xpert:| epoch 1 | 500/4029 batches | lr 0.000100 | ms/batch 125.36 | loss 1.52894 | mse 1.52894 |
15%|█▍ | 600/4029 [01:19<07:08, 7.99it/s]INFO:xpert:| epoch 1 | 600/4029 batches | lr 0.000100 | ms/batch 124.99 | loss 1.51555 | mse 1.51555 |
17%|█▋ | 700/4029 [01:32<06:55, 8.01it/s]INFO:xpert:| epoch 1 | 700/4029 batches | lr 0.000100 | ms/batch 124.94 | loss 1.48804 | mse 1.48804 |
20%|█▉ | 800/4029 [01:44<06:45, 7.97it/s]INFO:xpert:| epoch 1 | 800/4029 batches | lr 0.000099 | ms/batch 125.23 | loss 1.47788 | mse 1.47788 |
22%|██▏ | 900/4029 [01:57<06:30, 8.02it/s]INFO:xpert:| epoch 1 | 900/4029 batches | lr 0.000099 | ms/batch 125.18 | loss 1.45252 | mse 1.45252 |
25%|██▍ | 1000/4029 [02:09<06:17, 8.03it/s]INFO:xpert:| epoch 1 | 1000/4029 batches | lr 0.000098 | ms/batch 124.68 | loss 1.45935 | mse 1.45935 |
27%|██▋ | 1100/4029 [02:22<06:04, 8.03it/s]INFO:xpert:| epoch 1 | 1100/4029 batches | lr 0.000098 | ms/batch 124.91 | loss 1.44462 | mse 1.44462 |
30%|██▉ | 1200/4029 [02:34<05:55, 7.96it/s]INFO:xpert:| epoch 1 | 1200/4029 batches | lr 0.000097 | ms/batch 126.25 | loss 1.44614 | mse 1.44614 |
32%|███▏ | 1300/4029 [02:47<05:43, 7.95it/s]INFO:xpert:| epoch 1 | 1300/4029 batches | lr 0.000097 | ms/batch 125.06 | loss 1.43164 | mse 1.43164 |
35%|███▍ | 1400/4029 [03:00<05:36, 7.82it/s]INFO:xpert:| epoch 1 | 1400/4029 batches | lr 0.000096 | ms/batch 125.45 | loss 1.42629 | mse 1.42629 |
37%|███▋ | 1500/4029 [03:25<05:18, 7.93it/s] INFO:xpert:| epoch 1 | 1500/4029 batches | lr 0.000095 | ms/batch 251.57 | loss 1.41329 | mse 1.41329 |
40%|███▉ | 1600/4029 [03:37<05:08, 7.88it/s]INFO:xpert:| epoch 1 | 1600/4029 batches | lr 0.000094 | ms/batch 127.07 | loss 1.41402 | mse 1.41402 |
42%|████▏ | 1700/4029 [03:50<04:56, 7.85it/s]INFO:xpert:| epoch 1 | 1700/4029 batches | lr 0.000093 | ms/batch 127.20 | loss 1.39660 | mse 1.39660 |
45%|████▍ | 1800/4029 [04:03<04:42, 7.90it/s]INFO:xpert:| epoch 1 | 1800/4029 batches | lr 0.000092 | ms/batch 127.55 | loss 1.39432 | mse 1.39432 |
47%|████▋ | 1900/4029 [04:16<04:30, 7.86it/s]INFO:xpert:| epoch 1 | 1900/4029 batches | lr 0.000091 | ms/batch 127.03 | loss 1.35867 | mse 1.35867 |
50%|████▉ | 2000/4029 [04:28<04:19, 7.81it/s]INFO:xpert:| epoch 1 | 2000/4029 batches | lr 0.000090 | ms/batch 127.62 | loss 1.35443 | mse 1.35443 |
52%|█████▏ | 2100/4029 [04:41<04:04, 7.89it/s]INFO:xpert:| epoch 1 | 2100/4029 batches | lr 0.000088 | ms/batch 126.89 | loss 1.36290 | mse 1.36290 |
55%|█████▍ | 2200/4029 [04:54<03:54, 7.78it/s]INFO:xpert:| epoch 1 | 2200/4029 batches | lr 0.000087 | ms/batch 128.15 | loss 1.37226 | mse 1.37226 |
57%|█████▋ | 2300/4029 [05:07<03:41, 7.79it/s]INFO:xpert:| epoch 1 | 2300/4029 batches | lr 0.000086 | ms/batch 128.02 | loss 1.35059 | mse 1.35059 |
60%|█████▉ | 2400/4029 [05:19<03:27, 7.87it/s]INFO:xpert:| epoch 1 | 2400/4029 batches | lr 0.000084 | ms/batch 127.50 | loss 1.33494 | mse 1.33494 |
62%|██████▏ | 2500/4029 [05:32<03:14, 7.85it/s]INFO:xpert:| epoch 1 | 2500/4029 batches | lr 0.000083 | ms/batch 127.01 | loss 1.35183 | mse 1.35183 |
65%|██████▍ | 2600/4029 [05:45<03:01, 7.89it/s]INFO:xpert:| epoch 1 | 2600/4029 batches | lr 0.000081 | ms/batch 127.15 | loss 1.33104 | mse 1.33104 |
67%|██████▋ | 2700/4029 [05:57<02:47, 7.92it/s]INFO:xpert:| epoch 1 | 2700/4029 batches | lr 0.000079 | ms/batch 126.44 | loss 1.33160 | mse 1.33160 |
69%|██████▉ | 2800/4029 [06:10<02:35, 7.91it/s]INFO:xpert:| epoch 1 | 2800/4029 batches | lr 0.000078 | ms/batch 126.05 | loss 1.32298 | mse 1.32298 |
72%|███████▏ | 2900/4029 [06:23<02:22, 7.95it/s]INFO:xpert:| epoch 1 | 2900/4029 batches | lr 0.000076 | ms/batch 126.09 | loss 1.32695 | mse 1.32695 |
74%|███████▍ | 3000/4029 [06:35<02:09, 7.95it/s]INFO:xpert:| epoch 1 | 3000/4029 batches | lr 0.000074 | ms/batch 125.98 | loss 1.29774 | mse 1.29774 |
77%|███████▋ | 3100/4029 [06:48<01:56, 7.98it/s]INFO:xpert:| epoch 1 | 3100/4029 batches | lr 0.000072 | ms/batch 125.83 | loss 1.30883 | mse 1.30883 |
79%|███████▉ | 3200/4029 [07:00<01:43, 7.98it/s]INFO:xpert:| epoch 1 | 3200/4029 batches | lr 0.000070 | ms/batch 125.62 | loss 1.29395 | mse 1.29395 |
82%|████████▏ | 3300/4029 [07:25<01:33, 7.76it/s]INFO:xpert:| epoch 1 | 3300/4029 batches | lr 0.000069 | ms/batch 246.89 | loss 1.31994 | mse 1.31994 |
84%|████████▍ | 3400/4029 [07:38<01:19, 7.89it/s]INFO:xpert:| epoch 1 | 3400/4029 batches | lr 0.000067 | ms/batch 126.65 | loss 1.27893 | mse 1.27893 |
87%|████████▋ | 3500/4029 [07:51<01:07, 7.84it/s]INFO:xpert:| epoch 1 | 3500/4029 batches | lr 0.000065 | ms/batch 127.62 | loss 1.29397 | mse 1.29397 |
89%|████████▉ | 3600/4029 [08:03<00:55, 7.78it/s]INFO:xpert:| epoch 1 | 3600/4029 batches | lr 0.000063 | ms/batch 127.38 | loss 1.28557 | mse 1.28557 |
92%|█████████▏| 3700/4029 [08:16<00:41, 7.95it/s]INFO:xpert:| epoch 1 | 3700/4029 batches | lr 0.000061 | ms/batch 127.48 | loss 1.29124 | mse 1.29124 |
94%|█████████▍| 3800/4029 [08:29<00:28, 7.91it/s]INFO:xpert:| epoch 1 | 3800/4029 batches | lr 0.000059 | ms/batch 127.04 | loss 1.28448 | mse 1.28448 |
97%|█████████▋| 3900/4029 [08:41<00:16, 7.96it/s]INFO:xpert:| epoch 1 | 3900/4029 batches | lr 0.000057 | ms/batch 126.53 | loss 1.29239 | mse 1.29239 |
99%|█████████▉| 4000/4029 [08:54<00:03, 7.91it/s]INFO:xpert:| epoch 1 | 4000/4029 batches | lr 0.000055 | ms/batch 126.29 | loss 1.30030 | mse 1.30030 |
100%|██████████| 4029/4029 [08:58<00:00, 7.49it/s]
INFO:xpert:-----------------------------------------------------------------------------------------
INFO:xpert:| end of epoch 1 | time: 580.58s | valid loss/mse 1.2698 |
INFO:xpert:-----------------------------------------------------------------------------------------
INFO:xpert:Best model with score 1.2698
2%|▏ | 100/4029 [00:13<08:24, 7.79it/s]INFO:xpert:| epoch 2 | 100/4029 batches | lr 0.000052 | ms/batch 131.97 | loss 1.28284 | mse 1.28284 |
5%|▍ | 200/4029 [00:25<08:12, 7.78it/s]INFO:xpert:| epoch 2 | 200/4029 batches | lr 0.000050 | ms/batch 128.38 | loss 1.27764 | mse 1.27764 |
7%|▋ | 300/4029 [00:38<07:58, 7.79it/s]INFO:xpert:| epoch 2 | 300/4029 batches | lr 0.000048 | ms/batch 128.84 | loss 1.28330 | mse 1.28330 |
10%|▉ | 400/4029 [00:51<07:42, 7.85it/s]INFO:xpert:| epoch 2 | 400/4029 batches | lr 0.000046 | ms/batch 127.93 | loss 1.27734 | mse 1.27734 |
12%|█▏ | 500/4029 [01:04<07:28, 7.88it/s]INFO:xpert:| epoch 2 | 500/4029 batches | lr 0.000044 | ms/batch 127.14 | loss 1.25229 | mse 1.25229 |
15%|█▍ | 600/4029 [01:29<07:18, 7.82it/s] INFO:xpert:| epoch 2 | 600/4029 batches | lr 0.000042 | ms/batch 256.45 | loss 1.27088 | mse 1.27088 |
17%|█▋ | 700/4029 [01:42<07:04, 7.85it/s]INFO:xpert:| epoch 2 | 700/4029 batches | lr 0.000040 | ms/batch 128.22 | loss 1.24518 | mse 1.24518 |
20%|█▉ | 800/4029 [01:55<06:51, 7.85it/s]INFO:xpert:| epoch 2 | 800/4029 batches | lr 0.000038 | ms/batch 126.93 | loss 1.25825 | mse 1.25825 |
22%|██▏ | 900/4029 [02:08<06:39, 7.84it/s]INFO:xpert:| epoch 2 | 900/4029 batches | lr 0.000036 | ms/batch 129.29 | loss 1.25058 | mse 1.25058 |
25%|██▍ | 1000/4029 [02:21<06:32, 7.71it/s]INFO:xpert:| epoch 2 | 1000/4029 batches | lr 0.000034 | ms/batch 128.92 | loss 1.24740 | mse 1.24740 |
27%|██▋ | 1100/4029 [02:34<06:13, 7.83it/s]INFO:xpert:| epoch 2 | 1100/4029 batches | lr 0.000032 | ms/batch 128.40 | loss 1.26223 | mse 1.26223 |
30%|██▉ | 1200/4029 [02:46<06:00, 7.85it/s]INFO:xpert:| epoch 2 | 1200/4029 batches | lr 0.000030 | ms/batch 127.96 | loss 1.24876 | mse 1.24876 |
32%|███▏ | 1300/4029 [02:59<05:59, 7.58it/s]INFO:xpert:| epoch 2 | 1300/4029 batches | lr 0.000028 | ms/batch 128.64 | loss 1.24424 | mse 1.24424 |
35%|███▍ | 1400/4029 [03:12<05:38, 7.77it/s]INFO:xpert:| epoch 2 | 1400/4029 batches | lr 0.000026 | ms/batch 128.46 | loss 1.24871 | mse 1.24871 |
37%|███▋ | 1500/4029 [03:25<05:20, 7.90it/s]INFO:xpert:| epoch 2 | 1500/4029 batches | lr 0.000025 | ms/batch 127.30 | loss 1.24852 | mse 1.24852 |
40%|███▉ | 1600/4029 [03:38<05:07, 7.91it/s]INFO:xpert:| epoch 2 | 1600/4029 batches | lr 0.000023 | ms/batch 126.50 | loss 1.24752 | mse 1.24752 |
42%|████▏ | 1700/4029 [03:50<04:54, 7.91it/s]INFO:xpert:| epoch 2 | 1700/4029 batches | lr 0.000021 | ms/batch 126.44 | loss 1.23685 | mse 1.23685 |
45%|████▍ | 1800/4029 [04:03<04:40, 7.95it/s]INFO:xpert:| epoch 2 | 1800/4029 batches | lr 0.000019 | ms/batch 126.10 | loss 1.25094 | mse 1.25094 |
47%|████▋ | 1900/4029 [04:15<04:30, 7.86it/s]INFO:xpert:| epoch 2 | 1900/4029 batches | lr 0.000018 | ms/batch 126.22 | loss 1.23449 | mse 1.23449 |
50%|████▉ | 2000/4029 [04:28<04:16, 7.92it/s]INFO:xpert:| epoch 2 | 2000/4029 batches | lr 0.000016 | ms/batch 126.30 | loss 1.22911 | mse 1.22911 |
52%|█████▏ | 2100/4029 [04:41<04:02, 7.97it/s]INFO:xpert:| epoch 2 | 2100/4029 batches | lr 0.000015 | ms/batch 125.58 | loss 1.23448 | mse 1.23448 |
55%|█████▍ | 2200/4029 [04:53<03:49, 7.97it/s]INFO:xpert:| epoch 2 | 2200/4029 batches | lr 0.000013 | ms/batch 125.53 | loss 1.24162 | mse 1.24162 |
57%|█████▋ | 2300/4029 [05:06<03:36, 7.98it/s]INFO:xpert:| epoch 2 | 2300/4029 batches | lr 0.000012 | ms/batch 125.33 | loss 1.23597 | mse 1.23597 |
60%|█████▉ | 2400/4029 [05:30<03:24, 7.98it/s] INFO:xpert:| epoch 2 | 2400/4029 batches | lr 0.000011 | ms/batch 240.08 | loss 1.24450 | mse 1.24450 |
62%|██████▏ | 2500/4029 [05:42<03:12, 7.95it/s]INFO:xpert:| epoch 2 | 2500/4029 batches | lr 0.000010 | ms/batch 125.79 | loss 1.23866 | mse 1.23866 |
65%|██████▍ | 2600/4029 [05:55<02:59, 7.95it/s]INFO:xpert:| epoch 2 | 2600/4029 batches | lr 0.000008 | ms/batch 125.71 | loss 1.24764 | mse 1.24764 |
67%|██████▋ | 2700/4029 [06:07<02:47, 7.94it/s]INFO:xpert:| epoch 2 | 2700/4029 batches | lr 0.000007 | ms/batch 125.76 | loss 1.23106 | mse 1.23106 |
69%|██████▉ | 2800/4029 [06:20<02:34, 7.95it/s]INFO:xpert:| epoch 2 | 2800/4029 batches | lr 0.000006 | ms/batch 125.84 | loss 1.22217 | mse 1.22217 |
72%|███████▏ | 2900/4029 [06:33<02:22, 7.94it/s]INFO:xpert:| epoch 2 | 2900/4029 batches | lr 0.000005 | ms/batch 125.71 | loss 1.21713 | mse 1.21713 |
74%|███████▍ | 3000/4029 [06:45<02:09, 7.96it/s]INFO:xpert:| epoch 2 | 3000/4029 batches | lr 0.000004 | ms/batch 125.72 | loss 1.22231 | mse 1.22231 |
77%|███████▋ | 3100/4029 [06:58<01:56, 7.96it/s]INFO:xpert:| epoch 2 | 3100/4029 batches | lr 0.000004 | ms/batch 125.68 | loss 1.22234 | mse 1.22234 |
79%|███████▉ | 3200/4029 [07:10<01:44, 7.95it/s]INFO:xpert:| epoch 2 | 3200/4029 batches | lr 0.000003 | ms/batch 125.63 | loss 1.21855 | mse 1.21855 |
82%|████████▏ | 3300/4029 [07:23<01:31, 7.95it/s]INFO:xpert:| epoch 2 | 3300/4029 batches | lr 0.000002 | ms/batch 125.79 | loss 1.22944 | mse 1.22944 |
84%|████████▍ | 3400/4029 [07:35<01:19, 7.95it/s]INFO:xpert:| epoch 2 | 3400/4029 batches | lr 0.000002 | ms/batch 125.79 | loss 1.21225 | mse 1.21225 |
87%|████████▋ | 3500/4029 [07:48<01:06, 7.98it/s]INFO:xpert:| epoch 2 | 3500/4029 batches | lr 0.000001 | ms/batch 125.37 | loss 1.22758 | mse 1.22758 |
89%|████████▉ | 3600/4029 [08:01<00:53, 7.98it/s]INFO:xpert:| epoch 2 | 3600/4029 batches | lr 0.000001 | ms/batch 125.28 | loss 1.23964 | mse 1.23964 |
92%|█████████▏| 3700/4029 [08:13<00:41, 7.97it/s]INFO:xpert:| epoch 2 | 3700/4029 batches | lr 0.000000 | ms/batch 125.40 | loss 1.24234 | mse 1.24234 |
94%|█████████▍| 3800/4029 [08:26<00:28, 7.97it/s]INFO:xpert:| epoch 2 | 3800/4029 batches | lr 0.000000 | ms/batch 125.62 | loss 1.21881 | mse 1.21881 |
97%|█████████▋| 3900/4029 [08:38<00:16, 7.97it/s]INFO:xpert:| epoch 2 | 3900/4029 batches | lr 0.000000 | ms/batch 125.30 | loss 1.22792 | mse 1.22792 |
99%|█████████▉| 4000/4029 [08:51<00:03, 7.81it/s]INFO:xpert:| epoch 2 | 4000/4029 batches | lr 0.000000 | ms/batch 127.13 | loss 1.24660 | mse 1.24660 |
100%|██████████| 4029/4029 [08:55<00:00, 7.53it/s]
INFO:xpert:-----------------------------------------------------------------------------------------
INFO:xpert:| end of epoch 2 | time: 586.06s | valid loss/mse 1.2105 |
INFO:xpert:-----------------------------------------------------------------------------------------
INFO:xpert:Best model with score 1.2105
[14]:
# Quick inspection of the validation loader's length. `valid_loader` is
# defined in an earlier cell; presumably it is a torch DataLoader, in which
# case this is the number of validation batches — confirm against its setup.
len(valid_loader)
[14]:
530
[ ]:
10. Final Evaluation and Save Artifacts
[15]:
# Save the best checkpoint, run it on the test split, and persist the test
# results plus any metric lists collected during training under a result
# directory named after the learning rate and dosage mode.
import pickle


def _save_pickle(obj, path):
    """Pickle *obj* to *path*, closing the file handle deterministically.

    A `with` block guarantees the handle is flushed and closed even if
    pickling raises, instead of relying on garbage collection of an
    anonymous `open(...)` object.
    """
    with open(path, 'wb') as fh:
        pickle.dump(obj, fh)


# Save best model and evaluate on test set
torch.save(best_model.state_dict(), save_root / "model_best.pt")
test_loader = pert_data.dataloader["test_loader"]
test_res = eval_perturb_new(test_loader, best_model, device)
# test_metrics, test_pert_res = compute_metrics(test_res)  # optional

# Save results into a directory keyed by learning rate and dosage mode.
model_prefix = f'lr_{lr}_dosage_{dosage_mode_type}'
result_dir = save_root / f'result_{model_prefix}'
result_dir.mkdir(parents=True, exist_ok=True)

_save_pickle(test_res, result_dir / 'test_res.pkl')

# Metric lists are only written (and plotted) when training collected any.
if len(train_metrics_list) > 0:
    _save_pickle(train_metrics_list, result_dir / 'train_metrics_list.pkl')
    _save_pickle(train_metrics_pert_list, result_dir / 'train_metrics_pert_list.pkl')
    _save_pickle(val_metrics_list, result_dir / 'val_metrics_list.pkl')
    _save_pickle(val_metrics_pert_list, result_dir / 'val_metrics_pert_list.pkl')

    # Plotting is best-effort: a failure is logged as a warning, not raised,
    # so the saved artifacts above are never lost to a plotting error.
    try:
        merge_plot(train_metrics_list, 'train', str(result_dir / 'train.png'))
        merge_plot(val_metrics_list, 'val', str(result_dir / 'val.png'))
    except Exception as e:
        logger.warning(f"Plotting skipped: {e}")

# Free CUDA cache
torch.cuda.empty_cache()