X-Pert for Unseen Perturbation Prediction
This notebook is a refactored and streamlined version of the original workflow. It:
keeps the same behavior and outputs;
removes redundant code and paths;
organizes steps into clear sections with English markdown.
1. Imports
[2]:
# Core
import os, json, time, copy, warnings
from pathlib import Path
from typing import Dict
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch import nn
from tqdm import tqdm
import matplotlib.pyplot as plt
# X-Pert package
import xpert
from xpert.models import TransformerGenerator
from xpert.loss import masked_relative_error, masked_mse_loss, masked_huber_loss
from xpert.external_model.scgpt.gene_tokenizer import GeneVocab
from xpert.external_model.gears.inference import compute_metrics
from xpert.data import Byte_Pert_Data
# Utils
from xpert.utils import fix_seed, merge_plot
from xpert.external_model.scgpt.util import set_seed, map_raw_id_to_vocab_id, add_file_handler
2. Global Settings and Logger
[3]:
# Reproducibility: two project seeding helpers are called back to back with
# different seeds.
# NOTE(review): which RNGs each helper seeds is not visible in this notebook;
# if they overlap, the later call (2024) presumably wins — confirm the
# intended effective seed.
set_seed(42)
fix_seed(2024)
# Logger to file: run.log, the best checkpoint and the result directory are
# all created under `save_root` by later cells.
save_root = Path("./NormanWeissman2019_filtered/model_mode_scFlamingo_v18")
save_root.mkdir(parents=True, exist_ok=True)
logger = xpert.logger
add_file_handler(logger, save_root / "run.log")
logger.info(f"Running on {time.strftime('%Y-%m-%d %H:%M:%S')}")
# Plot default: reset matplotlib rcParams to their defaults.
plt.rcdefaults()
# Silences every Python warning for the rest of the session, including
# deprecation warnings that can be useful while debugging.
warnings.filterwarnings("ignore")
INFO:xpert:Running on 2025-10-31 22:48:13
3. Data Configuration
[4]:
# Dataset and split
prefix = 'Norman2019'  # dataset sub-directory used when locating pert_embed.csv
add_control = False  # forwarded to pert_data.get_Data_scgpt(add_control=...)

# Paths
# The former `if True: ... pass` scaffold wrapped only commented-out,
# machine-specific paths and executed nothing, so it has been dropped.
# The old intermediate-artifact location is kept below for reference only:
#   tmp_dir = Path('/nfs/public/lichen/results/single_cell_perturbation/perturbation_benchmark/scPerturb')
#   save_dir = tmp_dir / prefix / 'GEARS_v2-prefix_NormanWeissman2019_filtered-pert_cell_filter_100-seed_2024-split_type_1-var_num_5000-num_de_genes_20-bs_train_32-bs_test_32'
#   save_dir.mkdir(parents=True, exist_ok=True)

# scGPT pretrained model: config, weights and vocabulary live side by side.
load_model_dir = Path('../../data/scGPT_human')
model_config_file = load_model_dir / 'args.json'
model_file = load_model_dir / 'best_model.pt'
vocab_file = load_model_dir / 'vocab.json'
4. Model Hyperparameters
Note that we use `nlayers = 2` here to speed up training. The original paper uses `nlayers = 12` to reach the model's full performance.
[ ]:
# Training hyperparameters
# NOTE(review): `batch_size` / `eval_batch_size` are not referenced by the
# DataLoader construction in section 6, which uses `bs_train` / `bs_test`.
batch_size = 64
eval_batch_size = 64
log_interval = 100  # train() logs running averages every `log_interval` batches
# Cell encoder hyperparameters
# `embsize`, `d_hid` and `n_layers_cls` are fallbacks only: section 8 replaces
# them with the values from the pretrained args.json when those keys exist.
# `nlayers` and `nhead` are NOT overridden there (those lines are commented out).
embsize = 512
d_hid = 512
nlayers = 2
nhead = 8
n_layers_cls = 3
dropout = 0.2
use_fast_transformer = True
amp = True  # enables torch.cuda.amp autocast and the GradScaler
# Used as the DataParallel device list, as the primary device (first entry),
# and as a multiplier on bs_train / bs_test when the DataLoaders are built.
device_ids = [4, 5, 6, 7]
gpt_emb_dim = 1536  # last dimension of the per-batch perturbation embedding tensor
# add-token settings
load_encoder_plus = True # copy encoder weights to encoder_plus when adding tokens
# Gene selection mode read by train/evaluate/eval_perturb_new; the visible
# code handles "all" and "batch-wise".
include_zero_gene = "all"
# Task settings
# X-Pert settings
# `delta_mode` is read by the training/prediction routines, `load_cxg_weight`
# gates pretrained-weight loading, and `add_token` additionally gates the
# vocabulary extension; the remaining flags are forwarded unchanged to the
# TransformerGenerator constructor.
delta_mode = False
attn_gate_mode = True
load_cxg_weight = True
mask_mode = True
pert_mode = 'gene'
pert_flag_mode = True
use_scgpt_layer = True
use_scgpt_input = True
add_token = True
init_mode = False
cross_mode = True
# Scheduler
# `scheduler_type` must be one of 'steplr', 'cosine' or 'cosine_warm'
# (anything else raises ValueError when the scheduler is built).
epochs = 2
lr = 5e-5
scheduler_type = 'cosine_warm'
5. Helper Losses and Training/Eval Routines
[6]:
# LR scheduler (cosine with warmup)
from torch.optim.lr_scheduler import LambdaLR
import math
def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
    """Build a LambdaLR that warms up linearly and then decays along a half cosine.

    The multiplier applied to the optimizer's base learning rate is
    ``step / max(1, num_warmup_steps)`` while ``step < num_warmup_steps`` and
    ``max(0, 0.5 * (1 + cos(pi * progress)))`` afterwards, where ``progress``
    is the fraction of the post-warmup steps already taken.

    Args:
        optimizer: optimizer whose learning rate is scheduled.
        num_warmup_steps: number of linear warmup steps.
        num_training_steps: total number of scheduler steps (warmup included).

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR`` wrapping ``optimizer``.
    """
    # Both denominators are clamped to at least 1 to avoid division by zero.
    warmup_span = float(max(1, num_warmup_steps))
    decay_span = float(max(1, num_training_steps - num_warmup_steps))

    def lr_lambda(current_step):
        if current_step >= num_warmup_steps:
            progress = (current_step - num_warmup_steps) / decay_span
            cosine_factor = 0.5 * (1.0 + math.cos(math.pi * progress))
            return max(0.0, cosine_factor)
        return float(current_step) / warmup_span

    return LambdaLR(optimizer, lr_lambda)
# Training / Evaluation routines (behavior-preserving)
def train(model: nn.Module, train_loader: torch.utils.data.DataLoader) -> None:
    """Train ``model`` for one epoch over ``train_loader``.

    Configuration and training state are read from notebook-level globals:
    ``device``, ``gpt_emb_dim``, ``pert_embed_dict``, ``include_zero_gene``,
    ``n_genes``, ``max_seq_len``, ``delta_mode``, ``gene_ids``, ``amp``,
    ``criterion``, ``optimizer``, ``scaler``, ``scheduler``,
    ``scheduler_type``, ``log_interval``, ``logger`` and the current ``epoch``
    (used for logging only).

    Args:
        model: network called with the mapped gene ids, input values and
            perturbation flags; it must return a dict with ``"mlm_output"``.
        train_loader: iterable of batches exposing ``x``, ``y``,
            ``pert_flags``, ``pert`` and ``to(device)``.

    Raises:
        ValueError: if ``include_zero_gene`` is neither ``"all"`` nor
            ``"batch-wise"`` (previously this surfaced as a NameError).
    """
    if include_zero_gene not in ["all", "batch-wise"]:
        raise ValueError(
            f"Unsupported include_zero_gene={include_zero_gene!r}; "
            "expected 'all' or 'batch-wise'"
        )

    model.train()
    total_loss, total_mse = 0.0, 0.0
    start_time = time.time()
    num_batches = len(train_loader)
    for batch, batch_data in enumerate(tqdm(train_loader)):
        batch_size = len(batch_data.y)
        batch_data.to(device)
        ori_gene_values: torch.Tensor = batch_data.x
        pert_flags = batch_data.pert_flags.long()
        target_gene_values = batch_data.y

        # prepare perturbation embeddings (single- or multi-)
        # A name containing 'ctrl' keeps only its first '+'-separated part.
        batch_perts = [
            [name.split('+')[0]] if 'ctrl' in name else name.split('+')
            for name in batch_data.pert
        ]
        max_pert_len = max(len(perts) for perts in batch_perts)
        batch_pert_embed = torch.zeros(batch_size, max_pert_len, gpt_emb_dim).float()
        # 1 marks a padded perturbation slot, 0 a real one.
        pert_mask = torch.ones(batch_size, max_pert_len)
        for i, perts in enumerate(batch_perts):
            for j, pert in enumerate(perts):
                batch_pert_embed[i, j, :] = torch.tensor(np.array(pert_embed_dict[pert]))
                pert_mask[i, j] = 0
        batch_pert_embed = batch_pert_embed.to(device)
        # Keep the mask on the same device as the embeddings it describes.
        pert_mask = pert_mask.to(device)

        # select input genes
        if include_zero_gene == "all":
            input_gene_ids = torch.arange(n_genes, device=device, dtype=torch.long)
        else:
            # genes that are non-zero in at least one cell of the batch
            input_gene_ids = (
                ori_gene_values.nonzero()[:, 1].flatten().unique().sort()[0]
            )
        if len(input_gene_ids) > max_seq_len:
            # Randomly subsample the *selected* gene ids. The previous code
            # replaced them with the raw permutation positions, which picked
            # unrelated genes.
            keep = torch.randperm(len(input_gene_ids), device=device)[:max_seq_len]
            input_gene_ids = input_gene_ids[keep]
        input_values = ori_gene_values[:, input_gene_ids]
        input_pert_flags = pert_flags[:, input_gene_ids]
        target_values = target_gene_values[:, input_gene_ids]
        if delta_mode:
            # learn the change relative to the input expression
            target_values = target_values - input_values

        mapped_input_gene_ids = map_raw_id_to_vocab_id(input_gene_ids, gene_ids)
        mapped_input_gene_ids = mapped_input_gene_ids.repeat(batch_size, 1)
        # no position is treated as padding
        src_key_padding_mask = torch.zeros_like(input_values, dtype=torch.bool, device=device)

        with torch.cuda.amp.autocast(enabled=amp):
            output_dict = model(
                mapped_input_gene_ids,
                input_values,
                input_pert_flags,
                src_key_padding_mask=src_key_padding_mask,
                batch_pert_embed=batch_pert_embed,
                pert_mask=pert_mask,
                batch_dosages_pad=None,
            )
            output_values = output_dict["mlm_output"]
            # every selected gene contributes to the loss
            masked_positions = torch.ones_like(input_values, dtype=torch.bool)
            loss = loss_mse = criterion(output_values, target_values, masked_positions)

        model.zero_grad()
        scaler.scale(loss).backward()
        # unscale before clipping so the threshold applies to true gradients
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        # cosine schedules advance per batch; 'steplr' is stepped per epoch by the caller
        if scheduler_type in ['cosine', 'cosine_warm']:
            scheduler.step()

        total_loss += loss.item()
        total_mse += loss_mse.item()
        if batch % log_interval == 0 and batch > 0:
            lr_ = scheduler.get_last_lr()[0]
            ms_per_batch = (time.time() - start_time) * 1000 / log_interval
            cur_loss = total_loss / log_interval
            cur_mse = total_mse / log_interval
            logger.info(f"| epoch {epoch:3d} | {batch:3d}/{num_batches:3d} batches | lr {lr_:07.6f} | ms/batch {ms_per_batch:5.2f} | loss {cur_loss:5.5f} | mse {cur_mse:5.5f} |")
            total_loss = 0
            total_mse = 0
            start_time = time.time()
@torch.no_grad()
def evaluate(model: nn.Module, val_loader: torch.utils.data.DataLoader):
    """Compute the mean loss and mean masked relative error over ``val_loader``.

    Reads the notebook-level globals ``device``, ``gpt_emb_dim``,
    ``pert_embed_dict``, ``include_zero_gene``, ``n_genes``, ``max_seq_len``,
    ``delta_mode``, ``gene_ids``, ``amp`` and ``criterion``.

    Args:
        model: network called with the mapped gene ids, input values and
            perturbation flags; it must return a dict with ``"mlm_output"``.
        val_loader: iterable of batches exposing ``x``, ``y``, ``pert_flags``,
            ``pert`` and ``to(device)``.

    Returns:
        ``(mean_loss, mean_relative_error)``, each averaged over the number of
        batches in ``val_loader``.

    Raises:
        ValueError: if ``include_zero_gene`` is neither ``"all"`` nor
            ``"batch-wise"`` (previously this surfaced as a NameError).
    """
    if include_zero_gene not in ["all", "batch-wise"]:
        raise ValueError(
            f"Unsupported include_zero_gene={include_zero_gene!r}; "
            "expected 'all' or 'batch-wise'"
        )

    model.eval()
    total_loss = 0.0
    total_error = 0.0
    for batch, batch_data in enumerate(val_loader):
        batch_size = len(batch_data.y)
        batch_data.to(device)
        ori_gene_values: torch.Tensor = batch_data.x
        pert_flags = batch_data.pert_flags.long()
        target_gene_values = batch_data.y

        # A name containing 'ctrl' keeps only its first '+'-separated part.
        batch_perts = [
            [name.split('+')[0]] if 'ctrl' in name else name.split('+')
            for name in batch_data.pert
        ]
        max_pert_len = max(len(perts) for perts in batch_perts)
        batch_pert_embed = torch.zeros(batch_size, max_pert_len, gpt_emb_dim).float()
        # 1 marks a padded perturbation slot, 0 a real one.
        pert_mask = torch.ones(batch_size, max_pert_len)
        for i, perts in enumerate(batch_perts):
            for j, pert in enumerate(perts):
                batch_pert_embed[i, j, :] = torch.tensor(np.array(pert_embed_dict[pert]))
                pert_mask[i, j] = 0
        batch_pert_embed = batch_pert_embed.to(device)
        # Keep the mask on the same device as the embeddings it describes.
        pert_mask = pert_mask.to(device)

        if include_zero_gene == "all":
            input_gene_ids = torch.arange(n_genes, device=device)
        else:
            # genes that are non-zero in at least one cell of the batch
            input_gene_ids = (
                ori_gene_values.nonzero()[:, 1].flatten().unique().sort()[0]
            )
        if len(input_gene_ids) > max_seq_len:
            # Subsample the *selected* gene ids; the previous code replaced
            # them with the raw permutation positions.
            keep = torch.randperm(len(input_gene_ids), device=device)[:max_seq_len]
            input_gene_ids = input_gene_ids[keep]
        input_values = ori_gene_values[:, input_gene_ids]
        input_pert_flags = pert_flags[:, input_gene_ids]
        target_values = target_gene_values[:, input_gene_ids]
        if delta_mode:
            target_values = target_values - input_values

        mapped_input_gene_ids = map_raw_id_to_vocab_id(input_gene_ids, gene_ids)
        mapped_input_gene_ids = mapped_input_gene_ids.repeat(batch_size, 1)
        src_key_padding_mask = torch.zeros_like(input_values, dtype=torch.bool, device=input_values.device)

        with torch.cuda.amp.autocast(enabled=amp):
            output_dict = model(
                mapped_input_gene_ids,
                input_values,
                input_pert_flags,
                src_key_padding_mask=src_key_padding_mask,
                batch_pert_embed=batch_pert_embed,
                pert_mask=pert_mask,
                batch_dosages_pad=None,
            )
            output_values = output_dict["mlm_output"]
            # every selected gene contributes to the loss
            masked_positions = torch.ones_like(input_values, dtype=torch.bool, device=input_values.device)
            loss = criterion(output_values, target_values, masked_positions)

        total_loss += loss.item()
        total_error += masked_relative_error(output_values, target_values, masked_positions).item()
    return total_loss / len(val_loader), total_error / len(val_loader)
@torch.no_grad()
def pred_perturb_new(model, batch_data, include_zero_gene="batch-wise", gene_ids=None, amp=True):
    """Predict post-perturbation expression for every gene of one batch.

    Reads the notebook-level globals ``gpt_emb_dim``, ``pert_embed_dict`` and
    ``delta_mode``. The device is taken from the model's first parameter.

    Args:
        model: network called with the mapped gene ids, input values and
            perturbation flags; it must return a dict with ``"mlm_output"``.
        batch_data: batch exposing ``x``, ``pert_flags``, ``pert`` and
            ``to(device)``.
        include_zero_gene: ``"all"`` feeds every gene column of ``x``;
            ``"batch-wise"`` feeds only genes that are non-zero somewhere in
            the batch.
        gene_ids: vocabulary ids passed to ``map_raw_id_to_vocab_id``;
            required.
        amp: whether to run the forward pass under CUDA autocast.

    Returns:
        Tensor shaped like ``batch_data.x``. Genes that were not fed to the
        model are 0, or keep their input value when ``delta_mode`` is on.

    Raises:
        ValueError: if ``include_zero_gene`` is unsupported or ``gene_ids`` is
            ``None``.
    """
    if include_zero_gene not in ["all", "batch-wise"]:
        raise ValueError(
            f"Unsupported include_zero_gene={include_zero_gene!r}; "
            "expected 'all' or 'batch-wise'"
        )
    if gene_ids is None:
        raise ValueError("gene_ids must be provided")

    model.eval()
    device = next(model.parameters()).device
    batch_data.to(device)
    batch_size = len(batch_data.pert)
    ori_gene_values: torch.Tensor = batch_data.x
    pert_flags = batch_data.pert_flags.long()

    # A name containing 'ctrl' keeps only its first '+'-separated part.
    batch_perts = [
        [name.split('+')[0]] if 'ctrl' in name else name.split('+')
        for name in batch_data.pert
    ]
    max_pert_len = max(len(perts) for perts in batch_perts)
    batch_pert_embed = torch.zeros(batch_size, max_pert_len, gpt_emb_dim).float()
    # 1 marks a padded perturbation slot, 0 a real one.
    pert_mask = torch.ones(batch_size, max_pert_len)
    for i, perts in enumerate(batch_perts):
        for j, pert in enumerate(perts):
            batch_pert_embed[i, j, :] = torch.tensor(np.array(pert_embed_dict[pert]))
            pert_mask[i, j] = 0
    batch_pert_embed = batch_pert_embed.to(device)
    # Keep the mask on the same device as the embeddings it describes.
    pert_mask = pert_mask.to(device)

    if include_zero_gene == "all":
        input_gene_ids = torch.arange(ori_gene_values.size(1), device=device)
    else:
        input_gene_ids = (
            ori_gene_values.nonzero()[:, 1].flatten().unique().sort()[0]
        )
    input_values = ori_gene_values[:, input_gene_ids]
    input_pert_flags = pert_flags[:, input_gene_ids]

    mapped_input_gene_ids = map_raw_id_to_vocab_id(input_gene_ids, gene_ids)
    mapped_input_gene_ids = mapped_input_gene_ids.repeat(batch_size, 1)
    src_key_padding_mask = torch.zeros_like(input_values, dtype=torch.bool, device=device)

    with torch.cuda.amp.autocast(enabled=amp):
        output_dict = model(
            mapped_input_gene_ids,
            input_values,
            input_pert_flags,
            src_key_padding_mask=src_key_padding_mask,
            batch_pert_embed=batch_pert_embed,
            pert_mask=pert_mask,
            batch_dosages_pad=None,
        )
    output_values = output_dict["mlm_output"].float()

    # Scatter the predictions back into a full-width tensor.
    pred_gene_values = torch.zeros_like(ori_gene_values)
    pred_gene_values[:, input_gene_ids] = output_values
    if delta_mode:
        # The model predicted a delta; add the full-width input. The previous
        # code added `input_values`, whose width only matches in "all" mode.
        pred_gene_values = ori_gene_values + pred_gene_values
    return pred_gene_values
[7]:
def eval_perturb_new(
    loader,
    model,
    device: torch.device,
):
    """
    Run the model in inference mode over ``loader`` and collect predictions.

    Uses the notebook-level globals ``include_zero_gene`` and ``gene_ids``
    when calling ``pred_perturb_new``.

    Returns a dict with:
      - "pert_cat": perturbation name of every sample,
      - "pred" / "truth": predicted and true values on all genes,
      - "pred_de" / "truth_de": the same restricted to each sample's
        ``de_idx`` genes.
    All numeric entries are float64 numpy arrays.
    """
    model.eval()
    model.to(device)

    pert_names = []
    pred_rows, truth_rows = [], []
    pred_de_rows, truth_de_rows = [], []

    for batch in loader:
        batch.to(device)
        pert_names.extend(batch.pert)

        with torch.no_grad():
            batch_pred = pred_perturb_new(model, batch, include_zero_gene, gene_ids=gene_ids)
            batch_truth = batch.y
            # one entry per sample
            pred_rows.extend(batch_pred.cpu())
            truth_rows.extend(batch_truth.cpu())

            # Differentially expressed genes
            for row, de_idx in enumerate(batch.de_idx):
                pred_de_rows.append(batch_pred[row, de_idx])
                truth_de_rows.append(batch_truth[row, de_idx])

    def _as_float64(stacked):
        # detach, move to host and widen to float64
        return stacked.detach().cpu().numpy().astype(np.float64)

    results = {}

    # all genes
    results["pert_cat"] = np.array(pert_names)
    pred_all = torch.stack(pred_rows)
    truth_all = torch.stack(truth_rows)
    results["pred"] = _as_float64(pred_all)
    results["truth"] = _as_float64(truth_all)

    # DE genes
    pred_de_all = torch.stack(pred_de_rows)
    truth_de_all = torch.stack(truth_de_rows)
    results["pred_de"] = _as_float64(pred_de_all)
    results["truth_de"] = _as_float64(truth_de_all)

    return results
6. Data Loading and Tokenization
[8]:
# Build dataset via Byte_Pert_Data (behavior preserved)
# NOTE(review): `pert_cell_filter`, `seed`, `split_type`, `split_ratio`,
# `var_num` and `num_de_genes` are not passed to any call visible in this
# notebook; they presumably document how the pickled object was built —
# confirm against the upstream data pipeline.
pert_cell_filter = 100
seed = 2024
split_type = 1
split_ratio = [0.7, 0.2, 0.1]
var_num = 5000
num_de_genes = 20
# Base batch sizes; both are multiplied by len(device_ids) when the loaders
# are built below.
bs_train = 80
bs_test = bs_train * 2
# Load preprocessed pert_data object (same as original)
# You can also reconstruct it by running the upstream data pipeline
# SECURITY: unpickling can execute arbitrary code; only load trusted files.
pert_data_pkl = '../../data/Norman2019/pert_data.pkl'
pert_data = pd.read_pickle(pert_data_pkl)
# Build X-Pert training datasets
# Uses the `num_de_genes` stored on `pert_data`, not the local variable above.
pert_data.get_Data_scgpt(
num_de_genes=pert_data.num_de_genes,
dataset_name=['train', 'test', 'val'],
add_control=add_control,
)
# GEARS-specific additions
pert_data.modify_gears()
# DataLoaders
# The same loaders are read again later from pert_data.dataloader[...].
trainloader, testloader, valloader = pert_data.get_dataloader(
mode='all', bs_train=int(bs_train)*len(device_ids), bs_test=int(bs_test)*len(device_ids)
)
100%|██████████| 156/156 [00:38<00:00, 4.02it/s]
100%|██████████| 45/45 [00:11<00:00, 3.99it/s]
100%|██████████| 23/23 [00:06<00:00, 3.74it/s]
========== get Data_scGPT finished!
add adata finished
add condition finished
add set2conditions finished
7. Perturbation Embeddings and Vocabulary
[9]:
# Load perturbation embeddings. Columns are looked up by perturbation name
# below, so each column holds one embedding vector.
gpt_embed_root = Path('../../data/')
gene_embed = pd.read_csv(gpt_embed_root / prefix / 'pert_embed.csv', sep=',', index_col=0)

# Collect all perturbation names
total_perts = []
for pert_list in [pert_data.train_perts, pert_data.test_perts, pert_data.val_perts]:
    for pert in pert_list:
        if ';' in pert:
            # Split on the bare separator and strip each part, so both
            # "A; B" and "A;B" are handled. The previous code tested for ';'
            # but split on '; ', leaving "A;B" unsplit.
            total_perts.extend(
                part.strip() for part in pert.split(';') if part.strip()
            )
        else:
            total_perts.append(pert)
total_perts = np.unique(total_perts)

pert_embed_dict: Dict[str, np.ndarray] = {}
# Fixed seed so the random fallbacks below are reproducible.
np.random.seed(2024)
missing_perts = []
for pert in total_perts:
    if pert in gene_embed.columns:
        pert_embed_dict[pert] = gene_embed.loc[:, pert].values
    else:
        # No embedding for this perturbation: borrow a random gene's vector.
        fallback = np.random.choice(gene_embed.columns, 1)[0]
        pert_embed_dict[pert] = gene_embed.loc[:, fallback].values
        missing_perts.append(str(pert))
if missing_perts:
    # Surface the substitution instead of applying it silently.
    logger.warning(
        f"{len(missing_perts)} perturbation(s) have no embedding and were "
        f"assigned a random one: {missing_perts}"
    )

# Load vocab and extend if needed
vocab = GeneVocab.from_file(vocab_file)
special_tokens = ["<pad>", "<cls>", "<eoc>"]
for s in special_tokens:
    if s not in vocab:
        vocab.append_token(s)
# Snapshot taken BEFORE dataset genes are appended; its length is used as the
# token count when the model is built.
vocab_ori = copy.deepcopy(vocab)
if add_token:
    # add non-overlap genes to the vocab
    add_genes = np.setdiff1d(pert_data.adata.var_names, list(vocab.get_stoi().keys()))
    for gene in add_genes:
        if gene not in vocab:
            vocab.append_token(gene)

# Mark genes present in vocab (1 = present, -1 = absent)
pert_data.adata.var["id_in_vocab"] = [1 if gene in vocab else -1 for gene in pert_data.adata.var_names]
gene_ids_in_vocab = np.array(pert_data.adata.var["id_in_vocab"])
logger.info(f"match {np.sum(gene_ids_in_vocab >= 0)}/{len(gene_ids_in_vocab)} genes in vocabulary of size {len(vocab)}.")

# Build token ids for input genes
# NOTE(review): membership above is checked on `var_names`, whereas token ids
# here prefer the "gene_name" column when it exists — confirm they agree.
genes = pert_data.adata.var["gene_name"].tolist() if "gene_name" in pert_data.adata.var.columns else pert_data.adata.var_names.tolist()
vocab.set_default_index(vocab["<pad>"])
# Genes missing from the vocab fall back to the <pad> id.
gene_ids = np.array([vocab[g] if g in vocab else vocab["<pad>"] for g in genes], dtype=int)
n_genes = len(genes)

# Gene subsampling in train/evaluate only happens when more than this many
# genes are selected; the logged run has 5040 genes, so it is not triggered.
max_seq_len = 6000
INFO:xpert:match 5040/5040 genes in vocabulary of size 61194.
8. Build Model, Optimizer and Scheduler
[ ]:
# Load model config
with open(model_config_file, 'r') as f:
    model_configs = json.load(f)
# Architecture sizes follow the pretrained checkpoint when it provides them.
embsize = model_configs.get('embsize', embsize)
d_hid = model_configs.get('d_hid', d_hid)
n_layers_cls = model_configs.get('n_layers_cls', n_layers_cls)
# nhead / nlayers keep the notebook values (nlayers is reduced on purpose to
# speed up training), so deeper checkpoint layers show up as unexpected keys.
# nhead = model_configs.get('nheads', nhead)
# nlayers = model_configs.get('nlayers', nlayers)

# Token count: size of the ORIGINAL vocabulary (`vocab_ori`), i.e. before the
# dataset genes were appended. The extended `vocab` is passed separately
# below. Presumably this keeps the base `encoder` embedding shape-compatible
# with the pretrained weights — confirm against TransformerGenerator.
ntokens = len(vocab_ori)

# Model
model = TransformerGenerator(
    ntokens,
    embsize,
    nhead,
    d_hid,
    nlayers,
    nlayers_cls=n_layers_cls,
    n_cls=1,
    vocab=vocab,
    dropout=dropout,
    pad_token='<pad>',
    pad_value=0,
    pert_pad_id=2,
    use_fast_transformer=use_fast_transformer,
    pert_embed_dict=pert_embed_dict,
    gpt_emb_dim=gpt_emb_dim,
    # model_mode=model_mode,
    attn_gate_mode=attn_gate_mode,
    pert_mode=pert_mode,
    # drug_embed_mode=drug_embed_mode,
    pert_flag_mode=pert_flag_mode,
    use_scgpt_layer=use_scgpt_layer,
    use_scgpt_input=use_scgpt_input,
    mask_mode=mask_mode,
    add_token=add_token,
    init_mode=init_mode,
    # dosage_mode_type=dosage_mode_type,
    cross_mode=cross_mode,
)

# Optionally load pretrained partial weights
if load_cxg_weight:
    load_param_prefixs = ("encoder", "value_encoder", "transformer_encoder")
    # Load onto CPU so the checkpoint does not depend on the GPU layout it
    # was saved from; the model is moved to its device further down.
    state = torch.load(model_file, map_location="cpu")
    model_dict = model.state_dict()
    pretrained_dict = {k: v for k, v in state.items() if k.startswith(load_param_prefixs)}
    model_dict.update(pretrained_dict)
    # model_dict starts from the model's own state, so "missing" is always
    # empty; "unexpected" lists checkpoint entries the model has no slot for.
    missing = model.load_state_dict(model_dict, strict=False)
    print("Missing keys:", missing.missing_keys)
    print("Unexpected keys:", missing.unexpected_keys)

# If adding tokens, optionally load encoder weights into encoder_plus and init new tokens
if add_token and load_encoder_plus:
    with torch.no_grad():
        n = model.encoder.embedding.num_embeddings
        # copy base encoder weights and norms
        model.encoder_plus.embedding.weight[:n] = model.encoder.embedding.weight
        model.encoder_plus.enc_norm.weight[:] = model.encoder.enc_norm.weight
        model.encoder_plus.enc_norm.bias[:] = model.encoder.enc_norm.bias
        # initialize newly added tokens with mean/std of pretrained embeddings
        pretrained_embed = model.encoder.embedding.weight  # (n, d)
        mean = pretrained_embed.mean(dim=0)
        std = pretrained_embed.std(dim=0)
        m = model.encoder_plus.embedding.num_embeddings - n
        if m > 0:
            model.encoder_plus.embedding.weight[n:] = torch.normal(
                mean=mean.expand(m, -1),
                std=std.expand(m, -1)
            )

# Device & DP
# NOTE(review): `device_ids` must name GPUs that exist on this machine; the
# first entry is the primary device.
device = torch.device(f"cuda:{device_ids[0]}" if torch.cuda.is_available() else "cpu")
# The former `try: device_ids / except NameError` fallback was unreachable,
# because `device_ids[0]` is already evaluated on the line above.
if torch.cuda.device_count() > 1:
    model = torch.nn.DataParallel(model, device_ids=device_ids).to(device)
else:
    model = model.to(device)

# Loss, optimizer, scheduler
loss_type = 'mse'
criterion = masked_mse_loss if loss_type == 'mse' else masked_huber_loss
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# 'cosine' and 'cosine_warm' are stepped once per batch inside train();
# 'steplr' is stepped once per epoch by the training loop.
if scheduler_type == 'steplr':
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 5, gamma=0.98)
elif scheduler_type == 'cosine':
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=len(trainloader)*epochs)
elif scheduler_type == 'cosine_warm':
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(len(trainloader)*epochs*0.05),  # 5% of all steps
        num_training_steps=len(trainloader)*epochs,
    )
else:
    raise ValueError('Unknown scheduler')
scaler = torch.cuda.amp.GradScaler(enabled=amp)
Using simple batchnorm instead of domain specific batchnorm
Missing keys: []
Unexpected keys: ['transformer_encoder.layers.2.self_attn.Wqkv.weight', 'transformer_encoder.layers.2.self_attn.Wqkv.bias', 'transformer_encoder.layers.2.self_attn.out_proj.weight', 'transformer_encoder.layers.2.self_attn.out_proj.bias', 'transformer_encoder.layers.2.linear1.weight', 'transformer_encoder.layers.2.linear1.bias', 'transformer_encoder.layers.2.linear2.weight', 'transformer_encoder.layers.2.linear2.bias', 'transformer_encoder.layers.2.norm1.weight', 'transformer_encoder.layers.2.norm1.bias', 'transformer_encoder.layers.2.norm2.weight', 'transformer_encoder.layers.2.norm2.bias', 'transformer_encoder.layers.3.self_attn.Wqkv.weight', 'transformer_encoder.layers.3.self_attn.Wqkv.bias', 'transformer_encoder.layers.3.self_attn.out_proj.weight', 'transformer_encoder.layers.3.self_attn.out_proj.bias', 'transformer_encoder.layers.3.linear1.weight', 'transformer_encoder.layers.3.linear1.bias', 'transformer_encoder.layers.3.linear2.weight', 'transformer_encoder.layers.3.linear2.bias', 'transformer_encoder.layers.3.norm1.weight', 'transformer_encoder.layers.3.norm1.bias', 'transformer_encoder.layers.3.norm2.weight', 'transformer_encoder.layers.3.norm2.bias', 'transformer_encoder.layers.4.self_attn.Wqkv.weight', 'transformer_encoder.layers.4.self_attn.Wqkv.bias', 'transformer_encoder.layers.4.self_attn.out_proj.weight', 'transformer_encoder.layers.4.self_attn.out_proj.bias', 'transformer_encoder.layers.4.linear1.weight', 'transformer_encoder.layers.4.linear1.bias', 'transformer_encoder.layers.4.linear2.weight', 'transformer_encoder.layers.4.linear2.bias', 'transformer_encoder.layers.4.norm1.weight', 'transformer_encoder.layers.4.norm1.bias', 'transformer_encoder.layers.4.norm2.weight', 'transformer_encoder.layers.4.norm2.bias', 'transformer_encoder.layers.5.self_attn.Wqkv.weight', 'transformer_encoder.layers.5.self_attn.Wqkv.bias', 'transformer_encoder.layers.5.self_attn.out_proj.weight', 'transformer_encoder.layers.5.self_attn.out_proj.bias', 
'transformer_encoder.layers.5.linear1.weight', 'transformer_encoder.layers.5.linear1.bias', 'transformer_encoder.layers.5.linear2.weight', 'transformer_encoder.layers.5.linear2.bias', 'transformer_encoder.layers.5.norm1.weight', 'transformer_encoder.layers.5.norm1.bias', 'transformer_encoder.layers.5.norm2.weight', 'transformer_encoder.layers.5.norm2.bias', 'transformer_encoder.layers.6.self_attn.Wqkv.weight', 'transformer_encoder.layers.6.self_attn.Wqkv.bias', 'transformer_encoder.layers.6.self_attn.out_proj.weight', 'transformer_encoder.layers.6.self_attn.out_proj.bias', 'transformer_encoder.layers.6.linear1.weight', 'transformer_encoder.layers.6.linear1.bias', 'transformer_encoder.layers.6.linear2.weight', 'transformer_encoder.layers.6.linear2.bias', 'transformer_encoder.layers.6.norm1.weight', 'transformer_encoder.layers.6.norm1.bias', 'transformer_encoder.layers.6.norm2.weight', 'transformer_encoder.layers.6.norm2.bias', 'transformer_encoder.layers.7.self_attn.Wqkv.weight', 'transformer_encoder.layers.7.self_attn.Wqkv.bias', 'transformer_encoder.layers.7.self_attn.out_proj.weight', 'transformer_encoder.layers.7.self_attn.out_proj.bias', 'transformer_encoder.layers.7.linear1.weight', 'transformer_encoder.layers.7.linear1.bias', 'transformer_encoder.layers.7.linear2.weight', 'transformer_encoder.layers.7.linear2.bias', 'transformer_encoder.layers.7.norm1.weight', 'transformer_encoder.layers.7.norm1.bias', 'transformer_encoder.layers.7.norm2.weight', 'transformer_encoder.layers.7.norm2.bias', 'transformer_encoder.layers.8.self_attn.Wqkv.weight', 'transformer_encoder.layers.8.self_attn.Wqkv.bias', 'transformer_encoder.layers.8.self_attn.out_proj.weight', 'transformer_encoder.layers.8.self_attn.out_proj.bias', 'transformer_encoder.layers.8.linear1.weight', 'transformer_encoder.layers.8.linear1.bias', 'transformer_encoder.layers.8.linear2.weight', 'transformer_encoder.layers.8.linear2.bias', 'transformer_encoder.layers.8.norm1.weight', 
'transformer_encoder.layers.8.norm1.bias', 'transformer_encoder.layers.8.norm2.weight', 'transformer_encoder.layers.8.norm2.bias', 'transformer_encoder.layers.9.self_attn.Wqkv.weight', 'transformer_encoder.layers.9.self_attn.Wqkv.bias', 'transformer_encoder.layers.9.self_attn.out_proj.weight', 'transformer_encoder.layers.9.self_attn.out_proj.bias', 'transformer_encoder.layers.9.linear1.weight', 'transformer_encoder.layers.9.linear1.bias', 'transformer_encoder.layers.9.linear2.weight', 'transformer_encoder.layers.9.linear2.bias', 'transformer_encoder.layers.9.norm1.weight', 'transformer_encoder.layers.9.norm1.bias', 'transformer_encoder.layers.9.norm2.weight', 'transformer_encoder.layers.9.norm2.bias', 'transformer_encoder.layers.10.self_attn.Wqkv.weight', 'transformer_encoder.layers.10.self_attn.Wqkv.bias', 'transformer_encoder.layers.10.self_attn.out_proj.weight', 'transformer_encoder.layers.10.self_attn.out_proj.bias', 'transformer_encoder.layers.10.linear1.weight', 'transformer_encoder.layers.10.linear1.bias', 'transformer_encoder.layers.10.linear2.weight', 'transformer_encoder.layers.10.linear2.bias', 'transformer_encoder.layers.10.norm1.weight', 'transformer_encoder.layers.10.norm1.bias', 'transformer_encoder.layers.10.norm2.weight', 'transformer_encoder.layers.10.norm2.bias', 'transformer_encoder.layers.11.self_attn.Wqkv.weight', 'transformer_encoder.layers.11.self_attn.Wqkv.bias', 'transformer_encoder.layers.11.self_attn.out_proj.weight', 'transformer_encoder.layers.11.self_attn.out_proj.bias', 'transformer_encoder.layers.11.linear1.weight', 'transformer_encoder.layers.11.linear1.bias', 'transformer_encoder.layers.11.linear2.weight', 'transformer_encoder.layers.11.linear2.bias', 'transformer_encoder.layers.11.norm1.weight', 'transformer_encoder.layers.11.norm1.bias', 'transformer_encoder.layers.11.norm2.weight', 'transformer_encoder.layers.11.norm2.bias']
9. Training Loop and Validation
[11]:
# Training loop
# Tracks the best validation loss, keeps a deep copy of the best model and
# collects per-epoch train/validation metrics for later plotting.
best_val_loss = float('inf')
best_model = None
patience = 0
early_stop = epochs  # patience threshold, set to the epoch count
train_metrics_list, train_metrics_pert_list = [], []
val_metrics_list, val_metrics_pert_list = [], []

for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    train_loader = pert_data.dataloader["train_loader"]
    valid_loader = pert_data.dataloader["val_loader"]

    train(model, train_loader)
    val_loss, val_mre = evaluate(model, valid_loader)
    elapsed = time.time() - epoch_start_time

    logger.info("-" * 89)
    logger.info(f"| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | valid loss/mse {val_loss:5.4f} |")
    logger.info("-" * 89)

    # Written as `not (a < b)` so a NaN loss counts as "no improvement",
    # exactly like the plain `if a < b / else` form.
    if not (val_loss < best_val_loss):
        patience += 1
        if patience >= early_stop:
            logger.info(f"Early stop at epoch {epoch}")
            break
    else:
        best_val_loss = val_loss
        best_model = copy.deepcopy(model)
        logger.info(f"Best model with score {best_val_loss:5.4f}")
        patience = 0
        torch.save(best_model.state_dict(), save_root / "model_best.pt")

    # Per-epoch metrics on the full train and validation loaders; this is a
    # complete inference pass over both, so it adds noticeable runtime.
    train_res = eval_perturb_new(train_loader, model, device)
    val_res = eval_perturb_new(valid_loader, model, device)

    if train_res is not None:
        train_metrics, train_metrics_pert = compute_metrics(train_res)
        train_metrics_list.append(train_metrics)
        train_metrics_pert_list.append(train_metrics_pert)

    if val_res is not None:
        val_metrics, val_metrics_pert = compute_metrics(val_res)
        val_metrics_list.append(val_metrics)
        val_metrics_pert_list.append(val_metrics_pert)

    # StepLR advances once per epoch; the cosine schedules step inside train().
    if scheduler_type == 'steplr':
        scheduler.step()
48%|████▊ | 100/209 [01:26<01:24, 1.28it/s]INFO:xpert:| epoch 1 | 100/209 batches | lr 0.000045 | ms/batch 871.90 | loss 0.07585 | mse 0.07585 |
96%|█████████▌| 200/209 [02:49<00:07, 1.24it/s]INFO:xpert:| epoch 1 | 200/209 batches | lr 0.000029 | ms/batch 829.66 | loss 0.05854 | mse 0.05854 |
100%|██████████| 209/209 [02:56<00:00, 1.19it/s]
INFO:xpert:-----------------------------------------------------------------------------------------
INFO:xpert:| end of epoch 1 | time: 190.88s | valid loss/mse 0.0544 |
INFO:xpert:-----------------------------------------------------------------------------------------
INFO:xpert:Best model with score 0.0544
48%|████▊ | 100/209 [01:21<01:27, 1.24it/s]INFO:xpert:| epoch 2 | 100/209 batches | lr 0.000009 | ms/batch 823.23 | loss 0.05632 | mse 0.05632 |
96%|█████████▌| 200/209 [02:42<00:07, 1.24it/s]INFO:xpert:| epoch 2 | 200/209 batches | lr 0.000000 | ms/batch 806.34 | loss 0.05538 | mse 0.05538 |
100%|██████████| 209/209 [02:48<00:00, 1.24it/s]
INFO:xpert:-----------------------------------------------------------------------------------------
INFO:xpert:| end of epoch 2 | time: 183.47s | valid loss/mse 0.0536 |
INFO:xpert:-----------------------------------------------------------------------------------------
INFO:xpert:Best model with score 0.0536
10. Final Evaluation and Save Artifacts
[12]:
# Save best model and evaluate on test set
torch.save(best_model.state_dict(), save_root / "model_best.pt")
test_loader = pert_data.dataloader["test_loader"]
test_res = eval_perturb_new(test_loader, best_model, device)
# test_metrics, test_pert_res = compute_metrics(test_res) # optional

# Save results
model_prefix = f'lr_{lr}'
result_dir = save_root / f'result_{model_prefix}'
result_dir.mkdir(parents=True, exist_ok=True)

import pickle

# Each artifact is written inside a `with` block so the file handle is
# flushed and closed deterministically (the previous bare `open(...)` calls
# were never closed).
artifacts = {
    'test_res.pkl': test_res,
    'train_metrics_list.pkl': train_metrics_list,
    'train_metrics_pert_list.pkl': train_metrics_pert_list,
    'val_metrics_list.pkl': val_metrics_list,
    'val_metrics_pert_list.pkl': val_metrics_pert_list,
}
for file_name, payload in artifacts.items():
    with open(result_dir / file_name, 'wb') as fh:
        pickle.dump(payload, fh)

# Optional plotting if metrics were collected
# Best-effort: a plotting failure is logged and must not abort the run.
# NOTE(review): the validation metrics are plotted under the label 'test'
# and saved as test.png — confirm this naming is intended.
try:
    merge_plot(train_metrics_list, 'train', str(result_dir / 'train.png'))
    merge_plot(val_metrics_list, 'test', str(result_dir / 'test.png'))
except Exception as e:
    logger.warning(f"Plotting skipped: {e}")

# Free CUDA cache
torch.cuda.empty_cache()