Create co-training perturbation data
1. Import the necessary libraries
[1]:
import xpert as xp
from xpert.data.utils import get_info_txt
import scanpy as sc
import numpy as np
import pandas as pd
from types import MethodType
from tqdm import tqdm
import os
import pickle
import warnings
import logging
# Silence warnings and library chatter so the notebook output stays readable.
warnings.filterwarnings('ignore')
sc.settings.verbosity = 0
for _logger_name in ('scanpy', 'anndata'):
    logging.getLogger(_logger_name).setLevel(logging.ERROR)
2. Load the signature AnnData and select a specific cell line
[2]:
cell_line = "HT29"  # cell_id used below to select both treated signatures and untreated controls
[3]:
# Level-5 signature AnnData for all cell lines and perturbation types.
sig_path = '../../data/L1000_phase1_cotrain/adata_sig.h5ad'
adata_sig = sc.read(sig_path)
adata_sig
[3]:
AnnData object with n_obs × n_vars = 473647 × 978
obs: 'sig_id', 'pert_id', 'pert_iname', 'pert_type', 'cell_id', 'pert_dose', 'pert_dose_unit', 'pert_idose', 'pert_time', 'pert_time_unit', 'pert_itime', 'distil_id'
var: 'pr_gene_id', 'pr_gene_symbol', 'pr_gene_title', 'pr_is_lm', 'pr_is_bing'
[4]:
# Keep shRNA consensus-gene and compound signatures measured in the chosen cell line.
is_target_type = adata_sig.obs['pert_type'].isin(['trt_sh.cgs', 'trt_cp'])
is_target_cell = adata_sig.obs['cell_id'] == cell_line
adata_pert = adata_sig[is_target_type & is_target_cell].copy()
adata_pert
[4]:
AnnData object with n_obs × n_vars = 17815 × 978
obs: 'sig_id', 'pert_id', 'pert_iname', 'pert_type', 'cell_id', 'pert_dose', 'pert_dose_unit', 'pert_idose', 'pert_time', 'pert_time_unit', 'pert_itime', 'distil_id'
var: 'pr_gene_id', 'pr_gene_symbol', 'pr_gene_title', 'pr_is_lm', 'pr_is_bing'
[5]:
adata_pert.obs.pert_type.value_counts()
[5]:
trt_cp 14513
trt_sh.cgs 3302
Name: pert_type, dtype: int64
Add SMILES strings for compound perturbations
[6]:
# Sorted distinct pert_ids, used to restrict the reference table below.
unique_pert_ids = np.unique(adata_pert.obs['pert_id'])
pert_id_unique = pd.Series(unique_pert_ids)
print(f"# of unique perturbations: {len(pert_id_unique)}")
# of unique perturbations: 14215
[7]:
import pathlib
# Tab-separated LINCS perturbation metadata (provides canonical_smiles per pert_id).
pert_info_path = '../../data/L1000_phase1_cotrain/GSE92742_Broad_LINCS_pert_info.txt.gz'
reference_df = pd.read_csv(pert_info_path, sep="\t")
[8]:
# Only keep pert_ids present in this cell line, and only the columns needed for the join.
in_dataset = reference_df['pert_id'].isin(pert_id_unique)
reference_df = reference_df.loc[in_dataset, ['pert_id', 'canonical_smiles']]
reference_df['canonical_smiles'].value_counts()
[8]:
-666 3338
restricted 14
CS(=O)(=O)CCNCc1ccc(o1)-c1ccc2ncnc(Nc3ccc(OCc4cccc(F)c4)c(Cl)c3)c2c1 2
CCOC(=O)C1=C(NC(=C(C1C2=CC=CC=C2Cl)C(=O)OC)C)COCCN 2
CO[C@H]1\C=C\O[C@@]2(C)Oc3c(C2=O)c2c(O)c(\C=N\N4CCN(C)CC4)c(NC(=O)\C(C)=C/C=C/[C@H](C)[C@H](O)[C@@H](C)[C@@H](O)[C@@H](C)[C@H](OC(C)=O)[C@@H]1C)c(O)c2c(O)c3C 2
...
COc1ccc2c3c([C@@H](CO)N(C[C@]33CCN(CC3)C(=O)C3CCOCC3)C(C)=O)n(C)c2c1 1
CN(C)S(=O)(=O)c1ccc(N2CCCC2)c(c1)C(=O)N1CCN(CC1)c1ccccc1 1
NC(=O)c1sc(cc1N)-c1ccsc1 1
CC(=CCCN1CCC(CC1)NC(=O)[C@@](O)(C2CCCC2)c3ccccc3)C 1
OC[C@@H]1[C@H]2Cn3c(ccc(C4=CCCCC4)c3=O)[C@@H]([C@H]1C(O)=O)N2C(=O)Cc1ccccn1 1
Name: canonical_smiles, Length: 10858, dtype: int64
[9]:
# Attach each observation's canonical SMILES via its pert_id.
# Series.map keeps adata_pert's own obs index and row order, whereas pd.merge
# on a column would replace the signature obs_names with a fresh 0..n-1 index.
# map also raises if reference_df lists a pert_id more than once, instead of
# silently duplicating rows.
smiles_by_pert = reference_df.set_index('pert_id')['canonical_smiles']
# astype(str) so a categorical pert_id column still yields a plain object column.
adata_pert.obs['canonical_smiles'] = adata_pert.obs['pert_id'].astype(str).map(smiles_by_pert)
adata_pert.obs_names = adata_pert.obs_names.astype(str)
[10]:
# Compounds (trt_cp) and shRNA consensus genes (trt_sh.cgs) are processed separately.
pert_type = adata_pert.obs['pert_type']
adata_drug = adata_pert[pert_type == 'trt_cp'].copy()
adata_gene = adata_pert[pert_type == 'trt_sh.cgs'].copy()
adata_drug, adata_gene
[10]:
(AnnData object with n_obs × n_vars = 14513 × 978
obs: 'sig_id', 'pert_id', 'pert_iname', 'pert_type', 'cell_id', 'pert_dose', 'pert_dose_unit', 'pert_idose', 'pert_time', 'pert_time_unit', 'pert_itime', 'distil_id', 'canonical_smiles'
var: 'pr_gene_id', 'pr_gene_symbol', 'pr_gene_title', 'pr_is_lm', 'pr_is_bing',
AnnData object with n_obs × n_vars = 3302 × 978
obs: 'sig_id', 'pert_id', 'pert_iname', 'pert_type', 'cell_id', 'pert_dose', 'pert_dose_unit', 'pert_idose', 'pert_time', 'pert_time_unit', 'pert_itime', 'distil_id', 'canonical_smiles'
var: 'pr_gene_id', 'pr_gene_symbol', 'pr_gene_title', 'pr_is_lm', 'pr_is_bing')
[11]:
# Drop compounds without a usable SMILES string. '-666' and 'restricted' appear
# as placeholder values in the reference table; 'nan' is what missing values
# become after the str cast.
adata_drug.obs.loc[:, 'canonical_smiles'] = adata_drug.obs.canonical_smiles.astype('str')
placeholder_smiles = ['-666', 'restricted', 'nan']
invalid_smiles = adata_drug.obs['canonical_smiles'].isin(placeholder_smiles)
n_invalid = invalid_smiles.sum()
n_total = len(adata_drug)
print(f'Among {n_total} observations, {100*n_invalid/n_total:.2f}% ({n_invalid}) have an invalid SMILES string')
adata_drug = adata_drug[~invalid_smiles].copy()
adata_drug
Among 14513 observations, 0.76% (110) have an invalid SMILES string
[11]:
AnnData object with n_obs × n_vars = 14403 × 978
obs: 'sig_id', 'pert_id', 'pert_iname', 'pert_type', 'cell_id', 'pert_dose', 'pert_dose_unit', 'pert_idose', 'pert_time', 'pert_time_unit', 'pert_itime', 'distil_id', 'canonical_smiles'
var: 'pr_gene_id', 'pr_gene_symbol', 'pr_gene_title', 'pr_is_lm', 'pr_is_bing'
[12]:
# Deduplicate compounds: first collapse drugs that share an identical ECFP
# embedding, then drop observations whose SMILES string repeats.
from collections import defaultdict

pert_embed = pd.read_csv('../../data/L1000_phase1_cotrain/embed_ecfp.csv', sep=",", index_col=0)
# One embedding vector per column; columns are looked up by pert_iname below.
pert_embed_dict = {pert: pert_embed.loc[:, pert].values for pert in pert_embed.columns}

# Map every drug in the data to its embedding; fail loudly if any is missing
# (the previous "random embedding" fallback was unreachable dead code).
embed_drugs = np.unique(adata_drug.obs['pert_iname'])
missing_drugs = [drug for drug in embed_drugs if drug not in pert_embed_dict]
if missing_drugs:
    raise KeyError(f'{len(missing_drugs)} drugs have no ECFP embedding, e.g. {missing_drugs[:5]}')
drug_embedding_dict = {drug: pert_embed_dict[drug] for drug in embed_drugs}

# Group drugs with identical embedding vectors (ndarrays are unhashable, so key on a tuple).
embedding_to_drugs = defaultdict(list)
for drug, emb in drug_embedding_dict.items():
    embedding_to_drugs[tuple(emb.tolist())].append(drug)
# Keep one representative per distinct embedding: the first name in np.unique's sorted order.
unique_drugs = [drugs[0] for drugs in embedding_to_drugs.values()]
print(f'{len(unique_drugs)} distinct embeddings among {len(embed_drugs)} drugs')
adata_drug = adata_drug[adata_drug.obs['pert_iname'].isin(unique_drugs)].copy()

# Remove duplicated SMILES, keeping the first observation of each.
dup_mask = adata_drug.obs['canonical_smiles'].duplicated(keep='first')
adata_drug = adata_drug[~dup_mask, :].copy()
adata_drug
[12]:
AnnData object with n_obs × n_vars = 6353 × 978
obs: 'sig_id', 'pert_id', 'pert_iname', 'pert_type', 'cell_id', 'pert_dose', 'pert_dose_unit', 'pert_idose', 'pert_time', 'pert_time_unit', 'pert_itime', 'distil_id', 'canonical_smiles'
var: 'pr_gene_id', 'pr_gene_symbol', 'pr_gene_title', 'pr_is_lm', 'pr_is_bing'
3. Get control cells
[13]:
# Instance-level (level 3) profiles; the untreated controls are taken from here.
inst_path = '../../data/L1000_phase1_cotrain/adata_inst.h5ad'
adata_level3 = sc.read(inst_path)
adata_level3
[13]:
AnnData object with n_obs × n_vars = 1319138 × 978
obs: 'inst_id', 'rna_plate', 'rna_well', 'pert_id', 'pert_iname', 'pert_type', 'pert_dose', 'pert_dose_unit', 'pert_time', 'pert_time_unit', 'cell_id'
var: 'pr_gene_id', 'pr_gene_symbol', 'pr_gene_title', 'pr_is_lm', 'pr_is_bing'
[14]:
# Untreated control profiles of the chosen cell line (left as an AnnData view).
is_cell_line = adata_level3.obs['cell_id'] == cell_line
is_untreated = adata_level3.obs['pert_type'] == 'ctl_untrt'
adata_level3_part = adata_level3[is_cell_line & is_untreated]
adata_level3_part
[14]:
View of AnnData object with n_obs × n_vars = 1917 × 978
obs: 'inst_id', 'rna_plate', 'rna_well', 'pert_id', 'pert_iname', 'pert_type', 'pert_dose', 'pert_dose_unit', 'pert_time', 'pert_time_unit', 'cell_id'
var: 'pr_gene_id', 'pr_gene_symbol', 'pr_gene_title', 'pr_is_lm', 'pr_is_bing'
[15]:
import anndata as ad

# Collapse all untreated control profiles into one pseudo-observation holding
# their mean expression.
if adata_level3_part.n_obs == 0:
    raise ValueError(f'No untreated control profiles found for cell line {cell_line!r}')
# np.asarray keeps this a plain 2-D ndarray for both dense and scipy.sparse X
# (sparse .mean(axis=0) returns an np.matrix).
ctrl_mean = np.asarray(adata_level3_part.X.mean(axis=0)).reshape(1, -1)
adata_level3_ctrl = ad.AnnData(X=ctrl_mean,
                               # metadata (and obs name) of the first control profile is reused
                               obs=pd.DataFrame(adata_level3_part.obs.iloc[0, :]).T,
                               var=pd.DataFrame(index=adata_level3_part.var_names))
adata_level3_ctrl
[15]:
AnnData object with n_obs × n_vars = 1 × 978
obs: 'inst_id', 'rna_plate', 'rna_well', 'pert_id', 'pert_iname', 'pert_type', 'pert_dose', 'pert_dose_unit', 'pert_time', 'pert_time_unit', 'cell_id'
[16]:
# Outer join keeps every obs column from all three sources (missing values filled).
parts = [adata_drug, adata_gene, adata_level3_ctrl]
adata_concat = ad.concat(parts, join='outer')
adata_concat
[16]:
AnnData object with n_obs × n_vars = 9656 × 978
obs: 'sig_id', 'pert_id', 'pert_iname', 'pert_type', 'cell_id', 'pert_dose', 'pert_dose_unit', 'pert_idose', 'pert_time', 'pert_time_unit', 'pert_itime', 'distil_id', 'canonical_smiles', 'inst_id', 'rna_plate', 'rna_well'
4. Add basic attributes
[17]:
# Work on a copy so the obs columns added below do not touch adata_concat.
adata = adata_concat.copy()
[18]:
def get_perturbation_group(x):
    """Return the ``'<perturbation> | <dose> | <cell_id>'`` label for one obs row.

    Untreated rows (``pert_iname == 'UnTrt'``) use ``'control'`` as the
    perturbation name. Dose and cell id are converted with ``str`` so a
    non-string dose (e.g. a float) no longer makes ``str.join`` raise
    ``TypeError``; ``str(-666.0)`` is ``'-666.0'``.
    """
    return ' | '.join([str(get_perturbation_new(x)), str(x['dose']), str(x['cell_id'])])


def get_perturbation_new(x):
    """Return the perturbation name for one obs row, mapping ``'UnTrt'`` to ``'control'``."""
    if x['pert_iname'] == 'UnTrt':
        return 'control'
    return x['pert_iname']


def get_cell_type_new(x):
    """Return the row's ``cell_id`` unchanged."""
    return x['cell_id']
# 'dose' must exist before the group labels are built from it.
adata.obs['dose'] = adata.obs['pert_dose']
# Row-wise labels; each apply sees the columns added by the previous lines.
adata.obs['perturbation_group'] = perturbation_group = adata.obs.apply(get_perturbation_group, axis=1)
adata.obs['perturbation_new'] = perturbation_new = adata.obs.apply(get_perturbation_new, axis=1)
adata.obs['celltype_new'] = celltype_new = adata.obs.apply(get_cell_type_new, axis=1)
[ ]:
[ ]:
5. Generate the Pert_Data object
[19]:
# ---- Byte_Pert_Data parameters ----
# Filtering
pert_cell_filter = 0        # perturbations with fewer cells than this are filtered out
seed = 2024                 # random seed
# Splitting
split_type = 1              # 1: unseen perturbations; 0: unseen cell types
split_ratio = [0.8, 0.2, 0] # train : test : val (val is for model selection, test for final evaluation)
# Gene selection
var_num = 5000              # number of highly variable genes to select
num_de_genes = 20           # number of DE genes
# Loaders
bs_train = 32               # trainloader batch size
bs_test = 32                # testloader batch size
# Output
data_dir = '../../data/L1000_phase1_cotrain/'
[20]:
# Build the Byte_Pert_Data container from the parameters above.
pert_data = xp.data.Byte_Pert_Data(
    prefix='L1000_phase1',
    pert_cell_filter=pert_cell_filter,
    seed=seed,
    split_ratio=split_ratio,
    split_type=split_type,
    var_num=var_num,
    num_de_genes=num_de_genes,
    bs_train=bs_train,
    bs_test=bs_test,
)

print("Step 1: Reading files...")
pert_data.read_files(adata)

# Filtering / HVG selection are not run, so adata_split is simply the same
# object as adata_ori (presumably populated by read_files) - not a copy.
pert_data.adata_split = pert_data.adata_ori
split_obs = pert_data.adata_split.obs
is_perturbed = split_obs['perturbation_new'] != 'control'
# Every non-control perturbation_group (in order of first appearance) is a pairing target.
pert_data.filter_perturbation_list = list(split_obs.loc[is_perturbed, 'perturbation_group'].unique())
Step 1: Reading files...
========== read file finished!
6. Override the control-pairing and DE-gene functions
[21]:
def set_control_barcode(self, control_dose='-666.0'):
    """Pair every perturbed observation with a randomly drawn control observation.

    For each label in ``self.filter_perturbation_list`` (formatted as
    ``'<pert> | <dose> | <cell>'``), control barcodes are sampled with
    replacement from the rows whose ``perturbation_group`` equals
    ``'control | <control_dose> | <cell>'`` for the same cell.

    The result is stored in a ``control_barcode`` column of a copy of
    ``self.adata_split.obs`` (kept as ``self.obs_df_split``), which is then
    assigned back to ``self.adata_split.obs``. Rows whose group is not in
    ``filter_perturbation_list`` keep the string ``'None'``.

    Parameters
    ----------
    control_dose : str, default '-666.0'
        Dose label used in the control rows' ``perturbation_group``.

    Raises
    ------
    ValueError
        If a perturbation has observations but its cell has no control rows.
    """
    self.obs_df_split = self.adata_split.obs.copy()
    self.obs_df_split['control_barcode'] = 'None'

    # Index every perturbation_group once instead of re-scanning the whole obs
    # table twice per perturbation. groupby keeps the original row order inside
    # each group, so the sampled pairs match the former boolean-mask lookup.
    group_positions = self.obs_df_split.groupby('perturbation_group', sort=False).indices
    barcodes_by_group = {group: self.obs_df_split.index[positions].to_numpy()
                         for group, positions in group_positions.items()}
    no_barcodes = np.array([], dtype=object)

    np.random.seed(self.seed)
    for pert in tqdm(self.filter_perturbation_list):
        cell = pert.split(' | ')[-1]
        control_obs = barcodes_by_group.get(' | '.join(['control', control_dose, cell]), no_barcodes)
        obs_df_sub_idx = barcodes_by_group.get(pert, no_barcodes)
        if len(control_obs) == 0 and len(obs_df_sub_idx) > 0:
            raise ValueError(f'No control observations for {pert!r} '
                             f'(expected perturbation_group "control | {control_dose} | {cell}")')
        # Sample controls with replacement: one per perturbed observation.
        pair_control_obs = np.random.choice(control_obs, len(obs_df_sub_idx), replace=True)
        self.obs_df_split.loc[obs_df_sub_idx, 'control_barcode'] = pair_control_obs
    self.adata_split.obs = self.obs_df_split
    print('='*10,f'set control barcodes finished!')
def get_de_genes(self, rankby_abs=True, key_added='rank_genes_groups'):
    """Fill ``self.adata_split.uns`` with placeholder DE-gene results.

    No test is run: every perturbation in ``self.filter_perturbation_list`` gets
    all genes (``var_names``) as its "DE genes", and constant 0.1 values for
    ``pvals``, ``pvals_adj``, ``scores`` and ``logfoldchanges``.
    ``rankby_abs`` is accepted but not used.
    """
    var_names = self.adata_split.var_names
    n_vars = len(var_names)
    gene_lists = {}
    stat_tables = {name: {} for name in ('pvals', 'pvals_adj', 'scores', 'logfoldchanges')}
    for pert in tqdm(self.filter_perturbation_list):
        # Fresh list objects per perturbation (no shared references).
        gene_lists[pert] = list(var_names)
        for table in stat_tables.values():
            table[pert] = [0.1] * n_vars
    self.adata_split.uns[key_added] = gene_lists
    for name, table in stat_tables.items():
        self.adata_split.uns[name] = table
    print('='*10,f'get de genes finished!')
# Bind the notebook's set_control_barcode to this pert_data instance. An
# instance attribute shadows a same-named method on the class, so this
# presumably replaces the library's default pairing for this object only.
pert_data.set_control_barcode = MethodType(set_control_barcode, pert_data)
[ ]:
7. Run the processing pipeline
[22]:
# The library's filter_perturbation, get_and_process_adata (HVG selection),
# get_edis_2 and filter_sgRNA steps are not run for these L1000 signatures.
print("Step 2: Setting control barcodes...")
pert_data.set_control_barcode()

print("Step 3: Data splitting...")
# NOTE: this overrides the split_ratio passed to Byte_Pert_Data when it was built.
pert_data.split_ratio = [0.9, 0.1, 0]
# split_type comes from the parameters cell (1: unseen perturbations; 0: unseen cell types).
pert_data.data_split_2(split_type=split_type,
                       test_perts=None)

print("Step 4: Getting differential genes...")
get_de_genes(pert_data)

print("Step 5: Saving processed data...")
# Context manager guarantees the pickle is flushed and the file handle closed.
with open(os.path.join(data_dir, 'pert_data.pkl'), 'wb') as pkl_file:
    pickle.dump(pert_data, pkl_file)
print("Pert_Data object generation completed!")
print(f"Final dataset shape: {pert_data.adata_split.shape}")
print(f"Number of perturbation groups: {len(pert_data.filter_perturbation_list)}")
Step 2: Setting control barcodes...
100%|██████████| 9405/9405 [00:15<00:00, 600.49it/s]
========== set control barcodes finished!
Step 3: Data splitting...
========== data split finished!
Step 4: Getting differential genes...
100%|██████████| 9405/9405 [00:00<00:00, 11605.10it/s]
========== get de genes finished!
Step 5: Saving processed data...
Pert_Data object generation completed!
Final dataset shape: (9656, 978)
Number of perturbation groups: 9405
[ ]: