X-Pert Perturbverse visualization

Imports

[21]:
# Core
import os, json, time, copy, warnings
from pathlib import Path
from typing import Dict

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch import nn
from tqdm import tqdm
import matplotlib.pyplot as plt

from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

# X-Pert package
import xpert
from xpert.models import TransformerGenerator
from xpert.loss import masked_relative_error, masked_mse_loss, masked_huber_loss
from xpert.external_model.scgpt.gene_tokenizer import GeneVocab
from xpert.external_model.gears.inference import compute_metrics
from xpert.data import Byte_Pert_Data

# Utils
from xpert.utils import fix_seed, merge_plot
from xpert.external_model.scgpt.util import set_seed, map_raw_id_to_vocab_id, add_file_handler


Data configuration & Model Hyperparameters

[3]:

# Reproducibility: seed both the scGPT helper and the X-Pert helper.
set_seed(42)
fix_seed(2024)

# Logger to file: attach a file handler so the run is also recorded in run.log.
# NOTE(review): the output folder is named after NormanWeissman2019 while the
# dataset prefix below is Replogle2022_K562 -- confirm this is intended.
save_root = Path("./NormanWeissman2019_filtered/model_mode_scFlamingo_v18")
save_root.mkdir(parents=True, exist_ok=True)
logger = xpert.logger
add_file_handler(logger, save_root / "run.log")
logger.info(f"Running on {time.strftime('%Y-%m-%d %H:%M:%S')}")

# Dataset folder name under the data root.
prefix = 'Replogle2022_K562'

# Plot default: reset matplotlib rcParams and silence warnings for the notebook.
plt.rcdefaults()
warnings.filterwarnings("ignore")
INFO:xpert:Running on 2025-11-03 22:25:26
[15]:
# Training hyperparameters
# NOTE(review): these three are not read by the inference/plotting cells shown
# in this notebook; presumably kept from the training script.
batch_size = 64
eval_batch_size = 64
log_interval = 100

# Cell encoder hyperparameters
# embsize, d_hid and n_layers_cls are replaced by the values in the pretrained
# model's args.json when that file defines them; nlayers and nhead are used as
# written here.
embsize = 512
d_hid = 512
nlayers = 12
nhead = 8
n_layers_cls = 3
dropout = 0.2
use_fast_transformer = True
amp = True  # enables torch.cuda.amp.autocast around the full forward pass in 'output' mode
device_ids = [4, 5, 6, 7]  # GPU indices handed to DataParallel; the first one is the primary device
gpt_emb_dim = 1536  # length of each perturbation embedding vector; sizes the padded per-batch embedding tensor

# add-token settings
load_encoder_plus = True  # copy encoder weights to encoder_plus when adding tokens
include_zero_gene = "all"  # "all": feed every gene; "batch-wise": only genes that are non-zero somewhere in the batch


# Task settings

# X-Pert settings
# Most of these switches are forwarded unchanged to TransformerGenerator when
# the model is built; add_token additionally controls whether dataset genes
# missing from the pretrained vocabulary are appended to it.
delta_mode = False
attn_gate_mode = True
load_cxg_weight = True
mask_mode = True
pert_mode = 'gene'
pert_flag_mode = True
use_scgpt_layer = True
use_scgpt_input = True
add_token = True
init_mode = False
cross_mode = True

# Scheduler
# NOTE(review): no training loop is visible in this notebook, so these appear unused here.
epochs = 2
lr = 5e-5
scheduler_type = 'cosine_warm'


Load perturbation embeddings

[9]:
# Load the pickled perturbation dataset object; later cells read its
# .adata_split and .adata_ori attributes.
# NOTE(review): pd.read_pickle unpickles arbitrary Python objects -- only load
# this file from a trusted source.
pert_data_pkl = '../../data/Replogle2022_K562/pert_data.pkl'
pert_data = pd.read_pickle(pert_data_pkl)
[10]:
# Load perturbation embeddings: pert_embed.csv is read with its first column as
# the index, and every remaining column is taken as one perturbation's vector.
gpt_embed_root = Path('../../data/')
gene_embed = pd.read_csv(gpt_embed_root / prefix / 'pert_embed.csv', sep=',', index_col=0)

# Seeding is kept so the global NumPy RNG state after this cell is unchanged.
np.random.seed(2024)

# Map perturbation name -> embedding vector. The former random-column fallback
# was unreachable (the loop iterated over the very columns it tested
# membership against), so it has been dropped.
pert_embed_dict: Dict[str, np.ndarray] = {
    pert: gene_embed.loc[:, pert].values for pert in gene_embed.columns
}

Load model

[11]:
# Dataset and split
prefix = 'Replogle2022_K562'  # dataset folder name under ../../data/ (used to locate pert_embed.csv)
add_control = False  # NOTE(review): not referenced by the cells shown in this notebook


# scGPT pretrained model
load_model_dir = Path('../../data/scGPT_human')
model_config_file = load_model_dir / 'args.json'  # JSON with embsize / d_hid / n_layers_cls overrides
model_file = load_model_dir / 'best_model.pt'  # NOTE: reassigned to a fine-tuned checkpoint path before weights are loaded
vocab_file = load_model_dir / 'vocab.json'  # gene vocabulary read by GeneVocab.from_file

[13]:
# Read the pretrained vocabulary and guarantee the special tokens are present.
vocab = GeneVocab.from_file(vocab_file)
special_tokens = ["<pad>", "<cls>", "<eoc>"]
for token in special_tokens:
    if token not in vocab:
        vocab.append_token(token)

# Snapshot of the vocabulary before any dataset genes are appended.
vocab_ori = copy.deepcopy(vocab)

# From here on the split AnnData is the working AnnData.
pert_data.adata = pert_data.adata_split
if add_token:
    # Append every dataset gene the vocabulary does not know yet.
    known_tokens = list(vocab.get_stoi().keys())
    add_genes = np.setdiff1d(pert_data.adata.var_names, known_tokens)
    for new_gene in add_genes:
        if new_gene not in vocab:
            vocab.append_token(new_gene)

# Flag each dataset gene: 1 when it is in the vocabulary, -1 otherwise.
in_vocab_flags = []
for var_name in pert_data.adata.var_names:
    in_vocab_flags.append(1 if var_name in vocab else -1)
pert_data.adata.var["id_in_vocab"] = in_vocab_flags
gene_ids_in_vocab = np.array(pert_data.adata.var["id_in_vocab"])
logger.info(f"match {np.sum(gene_ids_in_vocab >= 0)}/{len(gene_ids_in_vocab)} genes in vocabulary of size {len(vocab)}.")

# Gene names used as model input: the "gene_name" column when present,
# otherwise the var index.
if "gene_name" in pert_data.adata.var.columns:
    genes = pert_data.adata.var["gene_name"].tolist()
else:
    genes = pert_data.adata.var_names.tolist()

# Unknown names resolve to the <pad> id.
pad_id = vocab["<pad>"]
vocab.set_default_index(pad_id)
gene_ids = np.array([vocab[name] if name in vocab else pad_id for name in genes], dtype=int)
n_genes = len(genes)

# Upper bound on the number of genes fed to the model per batch.
max_seq_len = 6000
INFO:xpert:match 5642/5642 genes in vocabulary of size 60878.
[16]:
# Load model config
with open(model_config_file, 'r') as f:
    model_configs = json.load(f)
# Take these sizes from the pretrained model's args.json; when a key is absent
# the notebook value defined earlier is kept.
embsize = model_configs.get('embsize', embsize)
d_hid = model_configs.get('d_hid', d_hid)
n_layers_cls = model_configs.get('n_layers_cls', n_layers_cls)

# The overrides for nhead / nlayers are disabled, so the notebook values are used.
# nhead = model_configs.get('nheads', nhead)
# nlayers = model_configs.get('nlayers', nlayers)

# Token count
ntokens = len(vocab_ori)  # size of the vocabulary snapshot taken BEFORE dataset genes were appended (not the extended vocab)

# Model
# Keyword configuration for the generator, collected in one place. The
# model_mode, drug_embed_mode and dosage_mode_type options are intentionally
# not passed.
generator_kwargs = dict(
    nlayers_cls=n_layers_cls,
    n_cls=1,
    vocab=vocab,
    dropout=dropout,
    pad_token='<pad>',
    pad_value=0,
    pert_pad_id=2,
    use_fast_transformer=use_fast_transformer,
    pert_embed_dict=pert_embed_dict,
    gpt_emb_dim=gpt_emb_dim,
    attn_gate_mode=attn_gate_mode,
    pert_mode=pert_mode,
    pert_flag_mode=pert_flag_mode,
    use_scgpt_layer=use_scgpt_layer,
    use_scgpt_input=use_scgpt_input,
    mask_mode=mask_mode,
    add_token=add_token,
    init_mode=init_mode,
    cross_mode=cross_mode,
)

# Positional arguments keep their original order:
# ntokens, embsize, nhead, d_hid, nlayers.
model = TransformerGenerator(ntokens, embsize, nhead, d_hid, nlayers, **generator_kwargs)

# Fine-tuned checkpoint whose weights are loaded into the model built above
# (this replaces the pretrained scGPT path stored in model_file earlier).
model_file = '/nfs/public/lichen/results/single_cell_perturbation/perturbation_benchmark/volc_result/scFlamingo/scPerturb_single_dataset/ReplogleWeissman2022_K562_essential/model_mode_scFlamingo_v13/model_best.pt'
from collections import OrderedDict

# Load the saved weights onto the CPU first, so the checkpoint can be read on
# hosts that lack the GPU it was saved from; the model is moved to its device
# afterwards.
saved_state_dict = torch.load(model_file, map_location='cpu')

# Checkpoints saved from a DataParallel wrapper prefix every key with
# 'module.'. Strip only that leading prefix: a blanket str.replace would also
# mangle keys that merely contain the substring (e.g. 'submodule.weight').
dp_prefix = 'module.'
new_state_dict = OrderedDict(
    (key[len(dp_prefix):] if key.startswith(dp_prefix) else key, value)
    for key, value in saved_state_dict.items()
)

# strict=False: report, rather than fail on, parameters that exist on only one side.
missing_keys_info = model.load_state_dict(new_state_dict, strict=False)
print("Missing keys:", missing_keys_info.missing_keys)
print("Unexpected keys:", missing_keys_info.unexpected_keys)

# Device & DP
# Number of GPUs actually usable on this host (0 when CUDA is unavailable).
n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0

# Use the pre-defined GPU list when one exists. The check must happen before
# the list is first used, otherwise the NameError fallback can never trigger.
try:
    requested_ids = list(device_ids)
except NameError:
    requested_ids = []

# Keep only the requested GPUs that exist here; when none of them do, fall
# back to every visible GPU instead of crashing on a missing device index.
device_ids = [idx for idx in requested_ids if 0 <= idx < n_gpus] or list(range(n_gpus))

# The first selected GPU is the primary device; with no GPU, run on CPU.
device = torch.device(f"cuda:{device_ids[0]}" if device_ids else "cpu")

if n_gpus > 1:
    model = torch.nn.DataParallel(model, device_ids=device_ids).to(device)
else:
    model = model.to(device)


Using simple batchnorm instead of domain specific batchnorm
Missing keys: ['DoseFiLM.mlp.0.weight', 'DoseFiLM.mlp.0.bias', 'DoseFiLM.mlp.2.weight', 'DoseFiLM.mlp.2.bias']
Unexpected keys: []

Get Perturbverse embeddings

[23]:

def return_xpert_embed_addtoken(mode='perceive', gene_list=None):
    """Compute one averaged X-Pert embedding per perturbation label.

    For every entry of ``gene_list`` the same 50 control cells (drawn with
    NumPy seed 2025) are paired with that perturbation and run through the
    global ``model`` in batches of 32; the per-cell results are averaged.

    Reads notebook globals: ``model``, ``pert_data``, ``pert_embed_dict``,
    ``device``, ``n_genes``, ``gene_ids``, ``max_seq_len``,
    ``include_zero_gene``, ``gpt_emb_dim`` and ``amp``.

    Args:
        mode: ``'perceive'`` collects the output of the model's ``perceive``
            module, averaged over its dim 1; ``'output'`` collects the
            ``"mlm_output"`` entry of the full forward pass.
        gene_list: perturbation labels. A combinatorial perturbation is
            written as ``"GENE_A; GENE_B"``.

    Returns:
        ``np.ndarray`` with one row per entry of ``gene_list``.

    Raises:
        ValueError: unknown ``mode``, empty/missing ``gene_list``, or an
            unsupported ``include_zero_gene`` setting.
        KeyError: a perturbation has no entry in ``pert_embed_dict``.
    """
    if mode not in ('perceive', 'output'):
        raise ValueError(f"mode must be 'perceive' or 'output', got {mode!r}")
    if gene_list is None or len(gene_list) == 0:
        raise ValueError("gene_list must contain at least one perturbation label")
    if include_zero_gene not in ("all", "batch-wise"):
        raise ValueError(f"unsupported include_zero_gene: {include_zero_gene!r}")

    model.eval()
    # DataParallel exposes the wrapped network as `.module`; a model that was
    # not wrapped has no such attribute, so fall back to the model itself.
    net = getattr(model, "module", model)

    # These flags are not defined in this notebook; use a notebook-level value
    # when the user has set one, otherwise pass False.
    # NOTE(review): confirm False is the intended default for inference.
    extra_flags = {name: globals().get(name, False) for name in ("CLS", "CCE", "MVC", "ECS")}

    # Fixed random subset of control cells shared by every perturbation.
    adata_split = pert_data.adata_split
    adata_ctrl = adata_split[adata_split.obs['perturbation_new'] == 'control']
    bs_infer = 32
    np.random.seed(2025)
    adata_ctrl = adata_ctrl[np.random.choice(adata_ctrl.obs_names, 50, replace=False)]
    var_idx_dict = {name: idx for idx, name in enumerate(adata_ctrl.var_names)}

    # Densify the control expression once instead of twice per perturbation.
    ctrl_matrix = adata_ctrl.X
    if hasattr(ctrl_matrix, "toarray"):
        ctrl_matrix = ctrl_matrix.toarray()
    else:
        ctrl_matrix = np.asarray(ctrl_matrix)

    ave_embed_list = []
    with torch.no_grad():
        for gene in tqdm(gene_list):
            pert_names = gene.split('; ')

            # Flag each perturbed gene that is among the measured genes.
            # (Indexing by the individual name, not the full label, so that
            # combinatorial labels work.)
            pert_flags = torch.zeros(adata_ctrl.shape[1])
            for _gene in pert_names:
                if _gene in var_idx_dict:
                    pert_flags[var_idx_dict[_gene]] = 1

            # GEARS-style label: "A+B" for combos, "A+ctrl" for single perts.
            if len(pert_names) > 1:
                gears_pert = '+'.join(pert_names)
            else:
                gears_pert = '+'.join(pert_names + ['ctrl'])

            cell_graphs = []
            for cell_expr in ctrl_matrix:
                cell_graphs.append(
                    Data(
                        x=torch.Tensor(cell_expr.reshape(1, -1)),
                        pert_idx=[-1],
                        y=torch.Tensor(cell_expr.reshape(1, -1)),
                        de_idx=[-1] * 20,
                        pert=gears_pert,
                        pert_flags=pert_flags.reshape(1, -1),
                    )
                )
            infer_dataloader = DataLoader(cell_graphs, batch_size=bs_infer, shuffle=False, drop_last=False)

            embed_list = []
            for batch_data in infer_dataloader:
                batch_size = len(batch_data.y)
                batch_data = batch_data.to(device)
                ori_gene_values = batch_data.x
                batch_pert_flags = batch_data.pert_flags.long()

                # Recover the perturbation names of every cell in the batch.
                batch_perts = [
                    [label.split('+')[0]] if 'ctrl' in label else label.split('+')
                    for label in batch_data.pert
                ]

                # Pad perturbation embeddings to the longest combo in the
                # batch; pert_mask is 0 on real entries and 1 on padding.
                max_pert_len = max(len(perts) for perts in batch_perts)
                batch_pert_embed = torch.zeros(batch_size, max_pert_len, gpt_emb_dim).float()
                pert_mask = torch.ones(batch_size, max_pert_len)
                for i, perts in enumerate(batch_perts):
                    for j, pert in enumerate(perts):
                        batch_pert_embed[i, j, :] = torch.tensor(np.array(pert_embed_dict[pert]))
                        pert_mask[i, j] = 0
                pert_mask = pert_mask.to(device)
                batch_pert_embed = batch_pert_embed.to(device)

                if include_zero_gene == "all":
                    input_gene_ids = torch.arange(n_genes, device=device)
                else:  # batch-wise: genes expressed by at least one cell
                    input_gene_ids = (
                        ori_gene_values.nonzero()[:, 1].flatten().unique().sort()[0]
                    )
                # Randomly subsample when there are more genes than the model
                # accepts. Index INTO the id list so the sampled values are
                # gene ids rather than positions.
                if len(input_gene_ids) > max_seq_len:
                    keep = torch.randperm(len(input_gene_ids), device=device)[:max_seq_len]
                    input_gene_ids = input_gene_ids[keep]

                input_values = ori_gene_values[:, input_gene_ids]
                input_pert_flags = batch_pert_flags[:, input_gene_ids]

                mapped_input_gene_ids = map_raw_id_to_vocab_id(input_gene_ids, gene_ids)
                mapped_input_gene_ids = mapped_input_gene_ids.repeat(batch_size, 1)
                src_key_padding_mask = torch.zeros_like(
                    input_values, dtype=torch.bool, device=input_values.device
                )

                if mode == 'perceive':
                    # Token / value / flag embedding steps, retained from the
                    # original cell. Their sum (total_embs) is not passed to
                    # `perceive` below; the steps are kept so any state they
                    # set on the model is unchanged.
                    if net.add_token:
                        src = net.encoder_plus(mapped_input_gene_ids)
                    else:
                        src = net.encoder(mapped_input_gene_ids)
                    net.cur_gene_token_embs = src
                    values = net.value_encoder(input_values)

                    if net.pert_mode == 'gene':
                        pert_embs = net.pert_encoder(input_pert_flags)
                    elif net.pert_mode == 'drug' and net.pert_flag_mode:
                        pert_embs = net.drug_pert_encoder(input_pert_flags)
                    elif net.pert_mode == 'drug' and not net.pert_flag_mode:
                        pass
                    else:
                        raise ValueError()

                    if net.pert_flag_mode:
                        total_embs = src + values + pert_embs
                    else:
                        total_embs = src + values
                    total_embs = net.bn(total_embs.permute(0, 2, 1)).permute(0, 2, 1)

                    # Project the perturbation embeddings, then run the
                    # perceiver over them.
                    if net.pert_mode == 'gene':
                        batch_pert_embed = net.linear_1(batch_pert_embed)
                    elif net.pert_mode == 'drug':
                        batch_pert_embed = net.drug_mlp(batch_pert_embed)
                    else:
                        raise ValueError()

                    perceived = net.perceive(batch_pert_embed, mask=pert_mask)
                    perceived = perceived.mean(dim=(1)).reshape(batch_size, -1)
                    embed_list.append(perceived.cpu().detach().numpy())

                else:  # mode == 'output'
                    with torch.cuda.amp.autocast(enabled=amp):
                        output_dict = model(
                            mapped_input_gene_ids,
                            input_values,
                            input_pert_flags,
                            src_key_padding_mask=src_key_padding_mask,
                            batch_pert_embed=batch_pert_embed,
                            CLS=extra_flags["CLS"],
                            CCE=extra_flags["CCE"],
                            MVC=extra_flags["MVC"],
                            ECS=extra_flags["ECS"],
                            do_sample=True,
                            pert_mask=pert_mask,
                            batch_dosages_pad=None,
                        )
                    output_values = output_dict["mlm_output"]
                    embed_list.append(output_values.cpu().detach().numpy())

            # Average over all control cells for this perturbation.
            ave_embed_list.append(np.mean(np.vstack(embed_list), axis=0))

    return np.vstack(ave_embed_list)

Load gene annotation from the original paper

[ ]:
import os.path as osp
save_dir_3 = '/nfs/public/lichen/results/single_cell_perturbation/perturbation_benchmark/perturbation_atlas/Replogle'
# Read the 'perturbation clusters' sheet of mmc3.xlsx (the sheet is selected by
# name, not by position); later cells read its 'members' and
# 'manual_annotation' columns.
df = pd.read_excel(osp.join(save_dir_3, 'mmc3.xlsx'), sheet_name='perturbation clusters')
[8]:

# Build a gene -> pathway lookup from the cluster table: every gene listed in
# a cluster's comma-separated 'members' cell gets that cluster's
# 'manual_annotation'.
gene_to_pathway = {}
for gene_cell, pathway in zip(df['members'], df['manual_annotation']):
    # Skip clusters without an annotation, and rows without members (an empty
    # cell would otherwise be stringified into a bogus gene called 'nan').
    if pd.isna(pathway) or pd.isna(gene_cell):
        continue
    # Split the member string and drop surrounding whitespace / empty items.
    # A dedicated name is used so the notebook-level `genes` list is not
    # overwritten.
    members = [gene.strip() for gene in str(gene_cell).split(',') if gene.strip()]
    for gene in members:
        # A gene listed in several clusters keeps the last annotation seen.
        gene_to_pathway[gene] = pathway
[ ]:
# Perturbation labels ordered by cell count, minus the first (most frequent) one.
# NOTE(review): [1:] presumably drops the control label -- confirm that it is
# always the most frequent entry of 'perturbation_new'.
gene_list_K562 = np.array(pert_data.adata_ori.obs['perturbation_new'].value_counts().index)[1:]

# Perturbations that also have a pathway annotation. Sorted because the
# iteration order of a set of strings can change between interpreter runs,
# which would reorder the embedding rows and the cross-validation folds.
label_genes = sorted(set(gene_list_K562) & set(gene_to_pathway.keys()))
print(len(label_genes))

504
504
[24]:
# One 'perceive' embedding row per annotated perturbation, in label_genes order.
xpert_embed = return_xpert_embed_addtoken(mode='perceive', gene_list=label_genes)
  1%|          | 3/504 [00:00<00:17, 29.05it/s]100%|██████████| 504/504 [00:14<00:00, 34.57it/s]

Visualize the Perturbverse

[26]:
import anndata as ad
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier
import scanpy as sc

# Embedding matrices to plot, paired with the name shown in the figure title.
name_list = ['perceive']
embeds_to_plot = [xpert_embed]

# Graph / plotting settings shared by every embedding.
n_neighbors = 10
n_pcs = 30
scale_mode = False
point_size = 20

# Figure defaults do not depend on the embedding, so set them once.
sc.settings.set_figure_params(dpi=180, facecolor='white')
plt.rcParams["figure.figsize"] = [3, 3]

for embed_name, _xpert_embed in zip(name_list, embeds_to_plot):

    # One observation per perturbation; the embedding dimensions are the variables.
    adata_sub = ad.AnnData(X=_xpert_embed,
                           obs=pd.DataFrame(index=label_genes))

    # Pathway label of each perturbation ('unknown' when it has no annotation).
    common_labels = [str(gene_to_pathway.get(pert, 'unknown')) for pert in adata_sub.obs_names]
    adata_sub.obs['pathway'] = common_labels

    # The embedding is used as-is: no filtering, normalisation or
    # highly-variable-gene selection is applied before PCA.
    adata_sub.raw = adata_sub
    if scale_mode:
        sc.pp.scale(adata_sub)
    # PCA -> neighbour graph -> UMAP, all with a fixed random state.
    sc.tl.pca(adata_sub, svd_solver='arpack', random_state=2022)
    sc.pp.neighbors(adata_sub, n_neighbors=n_neighbors, n_pcs=n_pcs, random_state=2022)
    sc.tl.umap(adata_sub, random_state=2022)

    sc.pl.umap(adata_sub,
               color=['pathway'],
               size=point_size, legend_fontsize=10,
               ncols=1, wspace=0.4, legend_loc='right margin',
               cmap='coolwarm',
               title=f'xpert predict - {embed_name}',
               )

    X_umap = adata_sub.obsm['X_umap']

    # Score how well pathways separate in the 2-D UMAP: 5-fold cross-validated
    # accuracy of a 5-nearest-neighbour classifier on the UMAP coordinates.
    knn = KNeighborsClassifier(n_neighbors=5)
    scores = cross_val_score(knn, X_umap, common_labels, cv=5)

    print(f"Classification Accuracy (mean): {scores.mean():.3f}")
    print(f"Classification Accuracy (std): {scores.std():.3f}")
... storing 'pathway' as categorical
../_images/tutorials_plot_perturbverse_21_1.png
Classification Accuracy (mean): 0.593
Classification Accuracy (std): 0.021
[ ]: