Create chemical perturbation data
This notebook builds the L1000 phase-1 chemical perturbation dataset used for X-Pert training.
1. Import the necessary libraries
[9]:
import xpert as xp
from xpert.data.utils import get_info_txt
import scanpy as sc
import numpy as np
from types import MethodType
from tqdm import tqdm
import os
import pickle
import warnings
import logging
# Quiet the notebook output: drop all Python warnings, set scanpy's own
# verbosity to 0, and discard scanpy/anndata log records below ERROR.
warnings.filterwarnings('ignore')
sc.settings.verbosity = 0
for noisy_logger in ('scanpy', 'anndata'):
    logging.getLogger(noisy_logger).setLevel(logging.ERROR)
2. Load the original AnnData file of perturbation data
[4]:
# Load the raw L1000 phase-1 AnnData and display its summary.
adata_path = '../../data/L1000_phase1/adata_L1000_phase1.h5ad'
adata = sc.read(adata_path)
adata
[4]:
AnnData object with n_obs × n_vars = 678401 × 978
obs: 'inst_id', 'rna_plate', 'rna_well', 'pert_id', 'pert_iname', 'pert_type', 'pert_dose', 'pert_dose_unit', 'pert_time', 'pert_time_unit', 'cell_id', 'canonical_smiles', 'plate_id'
var: 'pr_gene_id', 'pr_gene_symbol', 'pr_gene_title', 'pr_is_lm', 'pr_is_bing'
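Add the obs columns `perturbation_new` and `celltype_new` (plus the helper columns `dose` and `perturbation_group`), which are used for data preprocessing. DMSO-treated samples are relabelled as `control`.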
[5]:
def get_perturbation_group(x):
    """Build the ``'<perturbation> | <dose> | <cell>'`` group label for one obs row.

    Parameters
    ----------
    x : mapping-like
        One row of ``adata.obs`` (as passed by ``DataFrame.apply(..., axis=1)``);
        must provide ``'pert_iname'``, ``'dose'`` and ``'cell_id'``.

    Returns
    -------
    str
        The three parts joined by ``' | '``; the DMSO vehicle is renamed
        ``'control'``, e.g. ``'control | 10 | A375'``.
    """
    pert_name = 'control' if x['pert_iname'] == 'DMSO' else x['pert_iname']
    # str.join only accepts strings; doses (and ids) may be numeric, so coerce.
    return ' | '.join(str(part) for part in (pert_name, x['dose'], x['cell_id']))
def get_perturbation_new(x):
    """Return the row's perturbation name, mapping the DMSO vehicle to 'control'."""
    name = x['pert_iname']
    return 'control' if name == 'DMSO' else name
def get_cell_type_new(x):
    """Return the row's cell type label, taken verbatim from ``'cell_id'``."""
    cell_type = x['cell_id']
    return cell_type
# 'dose' is the column get_perturbation_group reads, so copy it from the raw
# 'pert_dose' first; then derive one new obs column per row-wise helper.
# The helpers only read 'pert_iname', 'dose' and 'cell_id'.
adata.obs['dose'] = adata.obs['pert_dose']
for new_column, row_fn in (
    ('perturbation_group', get_perturbation_group),
    ('perturbation_new', get_perturbation_new),
    ('celltype_new', get_cell_type_new),
):
    adata.obs[new_column] = adata.obs.apply(row_fn, axis=1)
[ ]:
[ ]:
3. Generate the Pert_Data object
[14]:
# ---- pipeline parameters ----
# filtering / reproducibility
pert_cell_filter = 0          # perts with fewer cells than this are filtered out
seed = 2024                   # random seed

# data splitting
split_type = 1                # 1: unseen perts; 0: unseen cell types
split_ratio = [0.8, 0.2, 0]   # train : test : val fractions (val for selection, test for final validation)

# genes
# NOTE(review): the dataset has only 978 genes and HVG selection is not run
# in this notebook, so var_num is presumably inert here — confirm.
var_num = 5000                # number of HVGs to select
num_de_genes = 20             # number of DE genes

# loaders / paths
bs_train = bs_test = 32       # batch sizes of the train / test loaders
data_dir = '../../data/L1000_phase1/'
[15]:
# Build the Byte_Pert_Data container from the parameters defined above.
pert_data = xp.data.Byte_Pert_Data(
    prefix='L1000_phase1',
    pert_cell_filter=pert_cell_filter,
    seed=seed,
    split_ratio=split_ratio,
    split_type=split_type,
    var_num=var_num,
    num_de_genes=num_de_genes,
    bs_train=bs_train,
    bs_test=bs_test,
)

print("Step 1: Reading files...")
pert_data.read_files(adata)

# No filtering / HVG selection is applied here, so the "split" AnnData is the
# original one (the same object, not a copy).
pert_data.adata_split = pert_data.adata_ori

# Record every non-control perturbation group, in first-appearance order.
split_obs = pert_data.adata_split.obs
is_perturbed = split_obs['perturbation_new'] != 'control'
pert_data.filter_perturbation_list = list(split_obs.loc[is_perturbed, 'perturbation_group'].unique())
Step 1: Reading files...
========== read file finished!
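4. Rewrite the cell-pairing functions: `set_control_barcode` pairs each perturbed cell with a random control cell from the same plate, cell type and time point, and `get_de_genes` fills placeholder DE-gene entries (all genes, constant 0.1 statistics).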
[16]:
def set_control_barcode(self):
    """Pair every perturbed cell with a randomly chosen matching control cell.

    Control cells (``perturbation_new == 'control'``) are pooled by the key
    ``'<plate_id> | <celltype_new> | <pert_time>'``. Each non-control cell
    draws one control obs name from the pool with the same key. The result is
    written to ``self.adata_split.obs['control_barcode']``; control cells get
    the string ``'None'``.

    The global NumPy RNG is seeded with ``self.seed``, and one
    ``np.random.choice`` draw is made per perturbed cell in obs order. This
    matches the previous implementation, which required an obs index of
    '0'..'n-1', so for such input the pairing and the global RNG state left
    for later steps are unchanged.

    Raises
    ------
    KeyError
        If a perturbed cell has no control cell sharing its
        plate / cell type / time key.
    """
    # Use self (the bound object) rather than a notebook global.
    obs = self.adata_split.obs

    # Iterate over the real index instead of assuming labels '0'..'n-1',
    # and avoid a per-cell .loc lookup.
    labels = [str(label) for label in obs.index]
    keys = [
        ' | '.join([str(plate), str(cell_type), str(pert_time)])
        for plate, cell_type, pert_time in zip(obs['plate_id'], obs['celltype_new'], obs['pert_time'])
    ]
    is_control = [perturbation == 'control' for perturbation in obs['perturbation_new']]

    # - pool control cells by plate / cell type / time (obs order preserved)
    control_pools = {}
    for label, key, control in zip(labels, keys, is_control):
        if control:
            control_pools.setdefault(key, []).append(label)
    # Convert once so np.random.choice does not rebuild an array on every draw.
    control_pools = {key: np.array(pool) for key, pool in control_pools.items()}

    # - draw one control per perturbed cell
    np.random.seed(self.seed)
    barcodes = ['None'] * len(labels)
    for position, (key, control) in enumerate(zip(keys, is_control)):
        if control:
            continue
        if key not in control_pools:
            raise KeyError(f"no control cells for plate | cell type | time key '{key}'")
        barcodes[position] = str(np.random.choice(control_pools[key], 1, replace=True)[0])

    obs['control_barcode'] = barcodes
    self.adata_split.obs = obs
def get_de_genes(self, rankby_abs=True, key_added='rank_genes_groups'):
    """Fill placeholder DE-gene results for every perturbation group.

    No differential-expression test is run. For each entry of
    ``self.filter_perturbation_list``:

    - the "DE genes" are all of ``self.adata_split.var_names`` (in var order);
    - every p-value, adjusted p-value, score and log-fold-change is the
      constant 0.1.

    Parameters
    ----------
    rankby_abs : bool
        Accepted for interface compatibility; not used.
    key_added : str
        ``uns`` key under which the per-perturbation gene lists are stored.

    Side effects
    ------------
    Writes ``uns[key_added]``, ``uns['pvals']``, ``uns['pvals_adj']``,
    ``uns['scores']`` and ``uns['logfoldchanges']`` on ``self.adata_split``.
    Each is a dict keyed by perturbation group. Prints a completion message.
    """
    gene_names = list(self.adata_split.var_names)
    n_genes = len(gene_names)
    placeholder = 0.1
    stat_keys = ('pvals', 'pvals_adj', 'scores', 'logfoldchanges')

    gene_dict = {}
    stat_dicts = {stat_key: {} for stat_key in stat_keys}
    for pert in tqdm(self.filter_perturbation_list):
        gene_dict[pert] = gene_names.copy()  # independent list per perturbation
        for per_pert in stat_dicts.values():
            per_pert[pert] = [placeholder] * n_genes

    self.adata_split.uns[key_added] = gene_dict
    for stat_key in stat_keys:
        self.adata_split.uns[stat_key] = stat_dicts[stat_key]
    print('=' * 10, 'get de genes finished!')
# Attach the notebook's set_control_barcode as a bound instance attribute; it
# shadows any class-level method of the same name on pert_data.
setattr(pert_data, 'set_control_barcode', MethodType(set_control_barcode, pert_data))
[ ]:
5. Run the processing pipeline
[17]:
# Steps of the generic Byte_Pert_Data workflow that are intentionally NOT run
# here: filter_perturbation, get_and_process_adata (HVG selection),
# get_edis_2 (E-distances) and filter_sgRNA.
print("Step 2: Setting control barcodes...")
pert_data.set_control_barcode()

print("Step 3: Data splitting...")
# Generate all three split types. NOTE(review): this ignores the `split_type`
# parameter set in the config cell — confirm that is intended.
for split_kind in (0, 1, 2):
    pert_data.data_split_2(split_type=split_kind, test_perts=None)

print("Step 4: Getting differential genes...")
get_de_genes(pert_data)

print("Step 5: Saving processed data...")
# Context manager guarantees the (large) pickle is flushed and the file closed.
with open(os.path.join(data_dir, 'pert_data.pkl'), 'wb') as pkl_file:
    pickle.dump(pert_data, pkl_file)

print("Pert_Data object generation completed!")
print(f"Final dataset shape: {pert_data.adata_split.shape}")
print(f"Number of perturbation groups: {len(pert_data.filter_perturbation_list)}")
Step 2: Setting control barcodes...
0%| | 0/678401 [00:00<?, ?it/s]100%|██████████| 678401/678401 [00:16<00:00, 40795.08it/s]
100%|██████████| 678401/678401 [01:18<00:00, 8594.04it/s]
Step 3: Data splitting...
========== data split finished!
========== data split finished!
========== data split finished!
Step 8: Getting differential genes...
100%|██████████| 128725/128725 [00:18<00:00, 6873.53it/s]
========== get de genes finished!
Step 9: Saving processed data...
Pert_Data object generation completed!
Final dataset shape: (678401, 978)
Number of perturbation groups: 128725
[ ]: