Create chemical perturbation data

This notebook builds the L1000 (phase 1) chemical-perturbation dataset used to train X-Pert.

1. Import the necessary libraries

[9]:
import xpert as xp
from xpert.data.utils import get_info_txt
import scanpy as sc
import numpy as np
from types import MethodType
from tqdm import tqdm
import os
import pickle
import warnings
import logging
# Keep notebook output quiet: ignore Python warnings and let the scanpy /
# anndata loggers report errors only.
warnings.filterwarnings('ignore')
sc.settings.verbosity = 0
for _logger_name in ('scanpy', 'anndata'):
    logging.getLogger(_logger_name).setLevel(logging.ERROR)

2. Load the original AnnData file of the perturbation data

[4]:
# Raw L1000 phase-1 AnnData (observations x genes); the repr is displayed below.
adata_path = '../../data/L1000_phase1/adata_L1000_phase1.h5ad'
adata = sc.read(adata_path)
adata
[4]:
AnnData object with n_obs × n_vars = 678401 × 978
    obs: 'inst_id', 'rna_plate', 'rna_well', 'pert_id', 'pert_iname', 'pert_type', 'pert_dose', 'pert_dose_unit', 'pert_time', 'pert_time_unit', 'cell_id', 'canonical_smiles', 'plate_id'
    var: 'pr_gene_id', 'pr_gene_symbol', 'pr_gene_title', 'pr_is_lm', 'pr_is_bing'

Add the obs columns `dose`, `perturbation_group`, `perturbation_new` and `celltype_new`, which are used for data preprocessing. DMSO-treated samples are labelled `control`.

[5]:
def get_perturbation_group(x):
    """Build the 'perturbation | dose | cell_id' group label for one obs row.

    Rows whose ``pert_iname`` is 'DMSO' are labelled 'control'. Every field is
    passed through ``str()`` so numeric or missing (NaN) doses and cell ids do not
    make ``' | '.join`` raise ``TypeError``; string fields are unchanged.

    Parameters
    ----------
    x : mapping (an obs row when used with ``DataFrame.apply(axis=1)``)
        Must provide 'pert_iname', 'dose' and 'cell_id'.

    Returns
    -------
    str
        e.g. 'control | 10 | A375' or 'aspirin | 10 | A375'.
    """
    perturbation = 'control' if x['pert_iname'] == 'DMSO' else x['pert_iname']
    return ' | '.join(str(part) for part in (perturbation, x['dose'], x['cell_id']))

def get_perturbation_new(x):
    """Return the perturbation name of an obs row, mapping 'DMSO' to 'control'."""
    name = x['pert_iname']
    return 'control' if name == 'DMSO' else name

def get_cell_type_new(x):
    """Return the cell identifier (the 'cell_id' field) of an obs row."""
    cell_line = x['cell_id']
    return cell_line

# Helper columns for preprocessing: a plain copy of the dose, then three
# row-wise derived labels. The label functions only read pert_iname, dose and
# cell_id, so all three can be computed before any of them is stored.
adata.obs['dose'] = adata.obs['pert_dose']

perturbation_group, perturbation_new, celltype_new = (
    adata.obs.apply(label_fn, axis=1)
    for label_fn in (get_perturbation_group, get_perturbation_new, get_cell_type_new)
)

adata.obs['perturbation_group'] = perturbation_group
adata.obs['perturbation_new'] = perturbation_new
adata.obs['celltype_new'] = celltype_new
[ ]:

[ ]:

3. Generate the Pert_Data object

[14]:
# ---- parameters -------------------------------------------------------------
# Perturbations with fewer cells than this are filtered out (0 keeps all).
pert_cell_filter = 0
seed = 2024                    # random seed
# 1 -> unseen perturbations, 0 -> unseen cell types. Note that the run cell
# below calls data_split_2 with explicit split_type values.
split_type = 1
# train : test : val -- val is used to choose data, test is for final validation.
split_ratio = [0.8, 0.2, 0]
# Number of highly variable genes to select. The loaded data has only 978 genes,
# and the HVG-selection step is not run in this notebook.
var_num = 5000
num_de_genes = 20              # number of DE genes
bs_train = bs_test = 32        # batch size of the train / test loaders
data_dir = '../../data/L1000_phase1/'
[15]:
# Build the Pert_Data container from the parameters defined above.
pert_data = xp.data.Byte_Pert_Data(
    prefix='L1000_phase1',
    seed=seed,
    pert_cell_filter=pert_cell_filter,
    split_type=split_type,
    split_ratio=split_ratio,
    var_num=var_num,
    num_de_genes=num_de_genes,
    bs_train=bs_train,
    bs_test=bs_test,
)

print("Step 1: Reading files...")
pert_data.read_files(adata)

# Perturbation filtering / HVG selection are not run here, so the split AnnData
# is the object held in adata_ori (presumably set by read_files). This is the
# same object, not a copy, so later obs edits affect both.
pert_data.adata_split = pert_data.adata_ori

# Record every non-control perturbation group (perturbation | dose | cell).
is_treated = pert_data.adata_split.obs['perturbation_new'] != 'control'
tmp_obs = pert_data.adata_split[is_treated].obs
pert_data.filter_perturbation_list = list(tmp_obs['perturbation_group'].unique())


Step 1: Reading files...
========== read file finished!

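4. Override the cell-pairing and DE-gene functions

`set_control_barcode` pairs each treated sample with a randomly drawn control (DMSO) sample from the same plate, cell type and treatment time. `get_de_genes` does not run a test; it fills placeholder DE results (all genes, constant values) for every perturbation group.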

[16]:
def set_control_barcode(self):
    """
    Pair every perturbed observation with a randomly drawn control observation.

    Controls (``obs['perturbation_new'] == 'control'``) are grouped by the key
    'plate_id | celltype_new | pert_time'. Each non-control observation gets the
    obs name of one control, drawn uniformly at random from its own group. The
    result is stored in ``obs['control_barcode']``; control rows get the string
    'None'.

    Rows are visited in positional order and the global NumPy RNG is seeded with
    ``self.seed`` before drawing, so the pairing is reproducible.

    Parameters
    ----------
    self : object exposing ``adata_split`` (with ``.obs``) and ``seed``
        Bound onto ``pert_data`` with ``MethodType`` in this notebook.

    Raises
    ------
    KeyError
        If a perturbed observation has no control in its plate/cell-type/time group.
    """
    obs_df = self.adata_split.obs

    def _rows():
        # Yield (obs name, pairing key, perturbation) for each row in positional order;
        # column-wise zip avoids one scalar .loc lookup per cell and column.
        for barcode, plate_id, cell_type, pert_time, perturbation in zip(
                obs_df.index, obs_df['plate_id'], obs_df['celltype_new'],
                obs_df['pert_time'], obs_df['perturbation_new']):
            key = ' | '.join([str(plate_id), str(cell_type), str(pert_time)])
            yield barcode, key, perturbation

    # - index control observations by plate / cell type / treatment time
    control_index = {}
    for barcode, key, perturbation in tqdm(_rows(), total=len(obs_df)):
        if perturbation == 'control':
            control_index.setdefault(key, []).append(barcode)

    # - draw one control per perturbed observation
    # One np.random.choice call per perturbed row, in row order, after seeding the
    # global RNG, so results match the original per-row implementation.
    np.random.seed(self.seed)
    paired = []
    for barcode, key, perturbation in tqdm(_rows(), total=len(obs_df)):
        if perturbation == 'control':
            paired.append('None')
            continue
        candidates = control_index.get(key)
        if not candidates:
            raise KeyError(
                f"no control observation for group {key!r} (obs {barcode!r})")
        paired.append(np.random.choice(candidates, 1, replace=True)[0])

    obs_df['control_barcode'] = paired
    self.adata_split.obs = obs_df

def get_de_genes(self,
                rankby_abs = True,
                key_added = 'rank_genes_groups'):
    """
    Fill ``self.adata_split.uns`` with placeholder differential-expression results.

    No statistical test is run. For every group in ``self.filter_perturbation_list``
    the "ranked" genes are all of ``self.adata_split.var_names`` in their original
    order. The p-value, adjusted p-value, score and log-fold-change entries are
    all the constant 0.1.

    Parameters
    ----------
    rankby_abs : bool
        Unused; kept so the signature stays unchanged.
    key_added : str
        ``uns`` key under which the per-group gene lists are stored.

    Side effects
    ------------
    Sets ``uns[key_added]``, ``uns['pvals']``, ``uns['pvals_adj']``,
    ``uns['scores']`` and ``uns['logfoldchanges']``, each a dict keyed by group.
    """
    placeholder = 0.1
    # Loop-invariant: list the genes once instead of five times per group.
    genes = list(self.adata_split.var_names)
    n_genes = len(genes)

    gene_dict = {}
    pvals_dict, pvals_adj_dict, scores_dict, logfoldchanges_dict = {}, {}, {}, {}

    for pert in tqdm(self.filter_perturbation_list):
        # Fresh lists per group, so mutating one group's entry cannot affect another.
        gene_dict[pert] = list(genes)
        pvals_dict[pert] = [placeholder] * n_genes
        pvals_adj_dict[pert] = [placeholder] * n_genes
        scores_dict[pert] = [placeholder] * n_genes
        logfoldchanges_dict[pert] = [placeholder] * n_genes

    self.adata_split.uns[key_added] = gene_dict
    self.adata_split.uns['pvals'] = pvals_dict
    self.adata_split.uns['pvals_adj'] = pvals_adj_dict
    self.adata_split.uns['scores'] = scores_dict
    self.adata_split.uns['logfoldchanges'] = logfoldchanges_dict
    print('='*10, 'get de genes finished!')

# Attach the override as a bound method so it can be called as
# pert_data.set_control_barcode(); get_de_genes is called as a plain function instead.
bound_set_control_barcode = MethodType(set_control_barcode, pert_data)
pert_data.set_control_barcode = bound_set_control_barcode
[ ]:

5. Run the processing pipeline

[17]:


# Steps of the generic pipeline that are not run for this dataset are kept as
# comments for reference (perturbation filtering, HVG selection, E-distances,
# sgRNA filtering).
# print("Step 2: Filtering perturbations...")
# pert_data.filter_perturbation()
# print("Step 3: Preprocessing adata and selecting HVGs...")
# pert_data.get_and_process_adata(var_num=pert_data.var_num)

print("Step 2: Setting control barcodes...")
pert_data.set_control_barcode()

# print("Step 5: Calculating E-distances...")
# pert_data.get_edis_2()
# print("Step 6: Filtering sgRNAs...")
# pert_data.adata_split.obs['sgRNA_new'] = 'control'
# pert_data.filter_sgRNA()

print("Step 3: Data splitting...")
# One split per split_type; what each type means is defined by data_split_2.
for split_id in (0, 1, 2):
    pert_data.data_split_2(split_type=split_id, test_perts=None)

print("Step 8: Getting differential genes...")
get_de_genes(pert_data)

print("Step 9: Saving processed data...")
# Save the processed data. The context manager guarantees the file is flushed
# and closed (the bare open(...) handle was never closed).
with open(os.path.join(data_dir, 'pert_data.pkl'), 'wb') as pkl_file:
    pickle.dump(pert_data, pkl_file)
print("Pert_Data object generation completed!")
print(f"Final dataset shape: {pert_data.adata_split.shape}")
print(f"Number of perturbation groups: {len(pert_data.filter_perturbation_list)}")
Step 2: Setting control barcodes...
  0%|          | 0/678401 [00:00<?, ?it/s]100%|██████████| 678401/678401 [00:16<00:00, 40795.08it/s]
100%|██████████| 678401/678401 [01:18<00:00, 8594.04it/s]
Step 3: Data splitting...
========== data split finished!
========== data split finished!
========== data split finished!
Step 8: Getting differential genes...
100%|██████████| 128725/128725 [00:18<00:00, 6873.53it/s]
========== get de genes finished!
Step 9: Saving processed data...
Pert_Data object generation completed!
Final dataset shape: (678401, 978)
Number of perturbation groups: 128725
[ ]: