Create cotrain perturbation data

1. import the necessary libraries

[1]:
import xpert as xp
from xpert.data.utils import get_info_txt
import scanpy as sc
import numpy as np
import pandas as pd
from types import MethodType
from tqdm import tqdm
import os
import pickle
import warnings
import logging
warnings.filterwarnings('ignore')
sc.settings.verbosity = 0
logging.getLogger('scanpy').setLevel(logging.ERROR)
logging.getLogger('anndata').setLevel(logging.ERROR)

2. load adata and select a specific cell line

[2]:
# Cell line to extract; compared against the `cell_id` column when the
# signature table and the control table are filtered in later cells.
cell_line = 'HT29'
[3]:
# Load the signature AnnData (.h5ad); the path is relative to the notebook's working directory.
adata_sig = sc.read('../../data/L1000_phase1_cotrain/adata_sig.h5ad')
adata_sig
[3]:
AnnData object with n_obs × n_vars = 473647 × 978
    obs: 'sig_id', 'pert_id', 'pert_iname', 'pert_type', 'cell_id', 'pert_dose', 'pert_dose_unit', 'pert_idose', 'pert_time', 'pert_time_unit', 'pert_itime', 'distil_id'
    var: 'pr_gene_id', 'pr_gene_symbol', 'pr_gene_title', 'pr_is_lm', 'pr_is_bing'
[4]:
# Restrict to the selected cell line and to the two perturbation types used
# downstream ('trt_sh.cgs' and 'trt_cp'); .copy() turns the view into an
# independent AnnData object.
adata_pert = adata_sig[
    adata_sig.obs['pert_type'].isin(['trt_sh.cgs', 'trt_cp'])
    & (adata_sig.obs['cell_id'] == cell_line)
].copy()
adata_pert
[4]:
AnnData object with n_obs × n_vars = 17815 × 978
    obs: 'sig_id', 'pert_id', 'pert_iname', 'pert_type', 'cell_id', 'pert_dose', 'pert_dose_unit', 'pert_idose', 'pert_time', 'pert_time_unit', 'pert_itime', 'distil_id'
    var: 'pr_gene_id', 'pr_gene_symbol', 'pr_gene_title', 'pr_is_lm', 'pr_is_bing'
[5]:
# Number of signatures per perturbation type in the selected subset.
adata_pert.obs['pert_type'].value_counts()
[5]:
trt_cp        14513
trt_sh.cgs     3302
Name: pert_type, dtype: int64

add smiles

[6]:
# Distinct perturbation ids of the subset (np.unique also sorts them);
# used in a later cell to filter the reference table.
pert_id_unique = pd.Series(np.unique(adata_pert.obs.pert_id))
print(f"# of unique perturbations: {len(pert_id_unique)}")
# of unique perturbations: 14215
[7]:
# NOTE(review): pathlib is not referenced in any of the cells shown in this notebook.
import pathlib
# Read the tab-separated perturbation info table (a .gz file read directly by pandas).
reference_df = pd.read_csv('../../data/L1000_phase1_cotrain/GSE92742_Broad_LINCS_pert_info.txt.gz', delimiter = "\t")
[8]:
# Keep only the perturbations that occur in adata_pert and only the two columns
# needed to attach SMILES strings to the observations.
reference_df = reference_df.loc[reference_df.pert_id.isin(pert_id_unique), ['pert_id', 'canonical_smiles']]
# Frequency of every SMILES value; this also surfaces placeholder values such as
# '-666' and 'restricted', which are removed from the drug subset in a later cell.
reference_df.canonical_smiles.value_counts()
[8]:
-666                                                                                                                                                             3338
restricted                                                                                                                                                         14
CS(=O)(=O)CCNCc1ccc(o1)-c1ccc2ncnc(Nc3ccc(OCc4cccc(F)c4)c(Cl)c3)c2c1                                                                                                2
CCOC(=O)C1=C(NC(=C(C1C2=CC=CC=C2Cl)C(=O)OC)C)COCCN                                                                                                                  2
CO[C@H]1\C=C\O[C@@]2(C)Oc3c(C2=O)c2c(O)c(\C=N\N4CCN(C)CC4)c(NC(=O)\C(C)=C/C=C/[C@H](C)[C@H](O)[C@@H](C)[C@@H](O)[C@@H](C)[C@H](OC(C)=O)[C@@H]1C)c(O)c2c(O)c3C       2
                                                                                                                                                                 ...
COc1ccc2c3c([C@@H](CO)N(C[C@]33CCN(CC3)C(=O)C3CCOCC3)C(C)=O)n(C)c2c1                                                                                                1
CN(C)S(=O)(=O)c1ccc(N2CCCC2)c(c1)C(=O)N1CCN(CC1)c1ccccc1                                                                                                            1
NC(=O)c1sc(cc1N)-c1ccsc1                                                                                                                                            1
CC(=CCCN1CCC(CC1)NC(=O)[C@@](O)(C2CCCC2)c3ccccc3)C                                                                                                                  1
OC[C@@H]1[C@H]2Cn3c(ccc(C4=CCCCC4)c3=O)[C@@H]([C@H]1C(O)=O)N2C(=O)Cc1ccccn1                                                                                         1
Name: canonical_smiles, Length: 10858, dtype: int64
[9]:
# Attach canonical_smiles to every observation with a left join on pert_id.
# validate='many_to_one' makes pandas raise a MergeError if reference_df lists a
# pert_id more than once, instead of silently duplicating observations (the
# merged table would then no longer have one row per observation of adata_pert).
adata_pert.obs = pd.merge(
    adata_pert.obs,
    reference_df,
    on='pert_id',
    how='left',
    validate='many_to_one',
)
# Merging on a column returns a fresh integer index (0..n-1), so the previous
# obs names are replaced; cast the new ones to str.
adata_pert.obs_names = adata_pert.obs_names.astype(str)
[10]:
# Split by perturbation type into the 'trt_cp' subset (adata_drug) and the
# 'trt_sh.cgs' subset (adata_gene); each one is an independent copy.
adata_drug = adata_pert[adata_pert.obs['pert_type']=='trt_cp'].copy()
adata_gene = adata_pert[adata_pert.obs['pert_type']=='trt_sh.cgs'].copy()
adata_drug, adata_gene
[10]:
(AnnData object with n_obs × n_vars = 14513 × 978
     obs: 'sig_id', 'pert_id', 'pert_iname', 'pert_type', 'cell_id', 'pert_dose', 'pert_dose_unit', 'pert_idose', 'pert_time', 'pert_time_unit', 'pert_itime', 'distil_id', 'canonical_smiles'
     var: 'pr_gene_id', 'pr_gene_symbol', 'pr_gene_title', 'pr_is_lm', 'pr_is_bing',
 AnnData object with n_obs × n_vars = 3302 × 978
     obs: 'sig_id', 'pert_id', 'pert_iname', 'pert_type', 'cell_id', 'pert_dose', 'pert_dose_unit', 'pert_idose', 'pert_time', 'pert_time_unit', 'pert_itime', 'distil_id', 'canonical_smiles'
     var: 'pr_gene_id', 'pr_gene_symbol', 'pr_gene_title', 'pr_is_lm', 'pr_is_bing')
[11]:
# Drop drug observations that have no usable SMILES string.
# Casting to str first turns missing values into the literal 'nan', so they can
# be matched together with the '-666' and 'restricted' placeholders.
adata_drug.obs.loc[:, 'canonical_smiles'] = adata_drug.obs['canonical_smiles'].astype(str)
invalid_smiles = adata_drug.obs['canonical_smiles'].isin(['-666', 'restricted', 'nan'])
print(f'Among {len(adata_drug)} observations, {100*invalid_smiles.sum()/len(adata_drug):.2f}% ({invalid_smiles.sum()}) have an invalid SMILES string')
adata_drug = adata_drug[~invalid_smiles].copy()
adata_drug
Among 14513 observations, 0.76% (110) have an invalid SMILES string
[11]:
AnnData object with n_obs × n_vars = 14403 × 978
    obs: 'sig_id', 'pert_id', 'pert_iname', 'pert_type', 'cell_id', 'pert_dose', 'pert_dose_unit', 'pert_idose', 'pert_time', 'pert_time_unit', 'pert_itime', 'distil_id', 'canonical_smiles'
    var: 'pr_gene_id', 'pr_gene_symbol', 'pr_gene_title', 'pr_is_lm', 'pr_is_bing'
[12]:
# - construct pert_embed to delete drugs that share the same embedding
from collections import defaultdict

pert_embed = pd.read_csv('../../data/L1000_phase1_cotrain/embed_ecfp.csv', sep = ",", index_col=0)

# One embedding vector per column of the table, keyed by the column label.
# (The former "not in pert_embed" fallback could never run because the loop
# iterated over pert_embed.columns itself, so it has been removed.)
pert_embed_dict = {pert: pert_embed.loc[:, pert].values for pert in pert_embed.columns}

# - create drug to embeddings
embed_drugs = np.unique(adata_drug.obs['pert_iname'])
missing_drugs = [drug for drug in embed_drugs if drug not in pert_embed_dict]
if missing_drugs:
    # Fail with the full picture instead of a bare KeyError on the first missing drug.
    raise KeyError(
        f'{len(missing_drugs)} drug(s) in adata_drug have no column in embed_ecfp.csv, '
        f'e.g. {missing_drugs[:5]}'
    )
drug_embedding_dict = {drug: pert_embed_dict[drug] for drug in embed_drugs}

# - group the drugs that have exactly the same embedding vector
embedding_to_drugs = defaultdict(list)
for drug, emb in drug_embedding_dict.items():
    emb_key = tuple(emb.tolist())  # turn the ndarray into a hashable tuple
    embedding_to_drugs[emb_key].append(drug)

# Keep one drug per distinct embedding: the first one in the sorted order
# produced by np.unique above.
unique_drugs = [drugs[0] for drugs in embedding_to_drugs.values()]
adata_drug = adata_drug[adata_drug.obs['pert_iname'].isin(unique_drugs)].copy()

# - remove duplicated smiles: only the first observation of every SMILES string is kept
dup_mask = adata_drug.obs['canonical_smiles'].duplicated(keep='first')   # True == row to drop
adata_drug = adata_drug[~dup_mask, :].copy()
adata_drug
[12]:
AnnData object with n_obs × n_vars = 6353 × 978
    obs: 'sig_id', 'pert_id', 'pert_iname', 'pert_type', 'cell_id', 'pert_dose', 'pert_dose_unit', 'pert_idose', 'pert_time', 'pert_time_unit', 'pert_itime', 'distil_id', 'canonical_smiles'
    var: 'pr_gene_id', 'pr_gene_symbol', 'pr_gene_title', 'pr_is_lm', 'pr_is_bing'

3. Get control cells

[13]:
# Load the second AnnData file ('adata_inst.h5ad'); the untreated control
# profiles are selected from it in the next cell.
adata_level3 = sc.read('../../data/L1000_phase1_cotrain/adata_inst.h5ad')
adata_level3
[13]:
AnnData object with n_obs × n_vars = 1319138 × 978
    obs: 'inst_id', 'rna_plate', 'rna_well', 'pert_id', 'pert_iname', 'pert_type', 'pert_dose', 'pert_dose_unit', 'pert_time', 'pert_time_unit', 'cell_id'
    var: 'pr_gene_id', 'pr_gene_symbol', 'pr_gene_title', 'pr_is_lm', 'pr_is_bing'
[14]:
# Select the untreated control profiles ('ctl_untrt') of the chosen cell line.
cond_1 = adata_level3.obs['cell_id'] == cell_line
cond_2 = adata_level3.obs['pert_type'] == 'ctl_untrt'
# Indexing without .copy() yields a view of adata_level3, not an independent object.
adata_level3_part = adata_level3[cond_1&cond_2]
adata_level3_part
[14]:
View of AnnData object with n_obs × n_vars = 1917 × 978
    obs: 'inst_id', 'rna_plate', 'rna_well', 'pert_id', 'pert_iname', 'pert_type', 'pert_dose', 'pert_dose_unit', 'pert_time', 'pert_time_unit', 'cell_id'
    var: 'pr_gene_id', 'pr_gene_symbol', 'pr_gene_title', 'pr_is_lm', 'pr_is_bing'
[15]:
import anndata as ad
# Collapse all selected control profiles into a single observation:
#   X   - mean over the control profiles for every variable, reshaped to one row
#   obs - metadata copied from the first control profile only
#   var - only the variable names are carried over (no annotation columns)
adata_level3_ctrl = ad.AnnData(X=np.mean(adata_level3_part.X, axis=0).reshape(1, -1),
                               obs = pd.DataFrame(adata_level3_part.obs.iloc[0,:]).T,
                               var = pd.DataFrame(index=adata_level3_part.var_names))
adata_level3_ctrl
[15]:
AnnData object with n_obs × n_vars = 1 × 978
    obs: 'inst_id', 'rna_plate', 'rna_well', 'pert_id', 'pert_iname', 'pert_type', 'pert_dose', 'pert_dose_unit', 'pert_time', 'pert_time_unit', 'cell_id'
[16]:
# Stack the drug signatures, the gene signatures and the single mean control
# into one AnnData; join='outer' keeps the union of the inputs' variables.
adata_concat = ad.concat([adata_drug, adata_gene, adata_level3_ctrl], join = 'outer')
adata_concat
[16]:
AnnData object with n_obs × n_vars = 9656 × 978
    obs: 'sig_id', 'pert_id', 'pert_iname', 'pert_type', 'cell_id', 'pert_dose', 'pert_dose_unit', 'pert_idose', 'pert_time', 'pert_time_unit', 'pert_itime', 'distil_id', 'canonical_smiles', 'inst_id', 'rna_plate', 'rna_well'

4. Add basic attributes

[17]:
# Work on a copy so that the obs columns added below do not modify adata_concat.
adata = adata_concat.copy()
[18]:
def get_perturbation_group(x):
    """Build the ``'<perturbation> | <dose> | <cell_id>'`` label of one observation.

    Args:
        x: Mapping-like row (for example one row of ``adata.obs``) that provides
            the keys ``'pert_iname'``, ``'dose'`` and ``'cell_id'``.

    Returns:
        str: The three fields joined by ``' | '``. A ``pert_iname`` equal to
        ``'UnTrt'`` is reported as ``'control'``.

    Every field is passed through ``str()`` before joining, so a non-string
    value (for instance a float dose) no longer makes ``str.join`` raise a
    ``TypeError``; string fields are left unchanged.
    """
    name = 'control' if x['pert_iname'] == 'UnTrt' else x['pert_iname']
    return ' | '.join(str(part) for part in (name, x['dose'], x['cell_id']))

def get_perturbation_new(x):
    """Return the perturbation name of one observation.

    Args:
        x: Mapping-like row providing the key ``'pert_iname'``.

    Returns:
        ``'control'`` when ``pert_iname`` equals ``'UnTrt'``, otherwise the
        ``pert_iname`` value itself.
    """
    name = x['pert_iname']
    return 'control' if name == 'UnTrt' else name

def get_cell_type_new(x):
    """Return the ``'cell_id'`` entry of one observation row unchanged."""
    return x['cell_id']

# get_perturbation_group reads x['dose'], so this column has to exist before
# the row-wise apply below.
adata.obs['dose'] = adata.obs['pert_dose']

# Row-wise label built by get_perturbation_group (perturbation, dose, cell_id).
perturbation_group = adata.obs.apply(get_perturbation_group, axis=1)
adata.obs['perturbation_group'] = perturbation_group

# Perturbation name with 'UnTrt' mapped to 'control'.
perturbation_new = adata.obs.apply(get_perturbation_new, axis=1)
adata.obs['perturbation_new'] = perturbation_new

# Cell line of every observation, taken from 'cell_id'.
celltype_new = adata.obs.apply(get_cell_type_new, axis=1)
adata.obs['celltype_new'] = celltype_new
[ ]:

[ ]:

5. generate Pert_Data object

[19]:
# parameters

pert_cell_filter = 0 # perturbations with fewer cells than this are filtered out
seed = 2024 # random seed
split_type = 1 # 1 for unseen perts; 0 for unseen celltypes
# train:test:val; val is used to choose data, test is for final validation
# NOTE: the run cell at the end of the notebook overwrites pert_data.split_ratio
# with [0.9, 0.1, 0] before the split is performed.
split_ratio = [0.8, 0.2, 0]
# number of highly variable genes (HVGs) to select
# NOTE(review): the HVG step (get_and_process_adata) is commented out in the run cell.
var_num = 5000
num_de_genes = 20 # number of de genes
bs_train = 32 # batch size of trainloader
bs_test = 32 # batch size of testloader
data_dir = '../../data/L1000_phase1_cotrain/' # directory that pert_data.pkl is written to in the run cell
[20]:
# Create the Pert_Data object from the values defined in the parameter cell
pert_data = xp.data.Byte_Pert_Data(
    prefix='L1000_phase1',
    pert_cell_filter=pert_cell_filter,
    seed=seed,
    split_ratio=split_ratio,
    split_type=split_type,
    var_num=var_num,
    num_de_genes=num_de_genes,
    bs_train=bs_train,
    bs_test=bs_test
)

# Complete data processing pipeline
print("Step 1: Reading files...")
# presumably this populates pert_data.adata_ori, which is read on the next
# statement - TODO confirm against xpert's Byte_Pert_Data.read_files
pert_data.read_files(adata)

# Get the filter_perturbation_list
# Plain assignment: adata_split and adata_ori refer to the same object, no copy is made.
pert_data.adata_split = pert_data.adata_ori
# obs table of every observation that is not labelled 'control'
tmp_obs = pert_data.adata_split[pert_data.adata_split.obs['perturbation_new']!='control'].obs
# distinct perturbation_group labels among the non-control observations
pert_data.filter_perturbation_list = list(tmp_obs['perturbation_group'].unique()) # record the perturbation pair


Step 1: Reading files...
========== read file finished!

6. Rewrite cell pair function

[21]:
def set_control_barcode(self, control_dose='-666.0'):
    """Pair every perturbed observation with a randomly drawn control observation.

    Works on a copy of ``self.adata_split.obs`` (stored as ``self.obs_df_split``),
    adds a ``'control_barcode'`` column initialised to the string ``'None'`` and,
    for each label in ``self.filter_perturbation_list``, fills it with obs names
    sampled with replacement from the control group of the same cell line.
    The resulting table is written back to ``self.adata_split.obs``.
    ``np.random`` is seeded with ``self.seed`` before sampling.

    Args:
        control_dose: Dose field of the control group's label. Control rows are
            looked up as ``'control | <control_dose> | <cell_id>'`` in the
            ``'perturbation_group'`` column, where ``<cell_id>`` is the last
            ``' | '``-separated field of the perturbation label.

    Raises:
        ValueError: If a perturbation has observations but no control
            observation exists for its cell line.
    """

    self.obs_df_split = self.adata_split.obs.copy()

    # - set all control_barcode to None
    self.obs_df_split['control_barcode'] = 'None'

    # The group labels are never modified below, so read the column once.
    group_labels = self.obs_df_split['perturbation_group']
    # Control obs names per cell line, resolved once instead of once per perturbation.
    control_obs_by_cell = {}

    np.random.seed(self.seed)
    for pert in tqdm(self.filter_perturbation_list):
        # - get all the control barcodes of this perturbation's cell line
        cell_id = pert.split(' | ')[-1]
        if cell_id not in control_obs_by_cell:
            control_group = ' | '.join(['control', control_dose, cell_id])
            control_mask = (group_labels == control_group).to_numpy()
            control_obs_by_cell[cell_id] = np.array(self.obs_df_split.index[control_mask])
        control_obs = control_obs_by_cell[cell_id]

        pert_mask = (group_labels == pert).to_numpy()
        obs_df_sub_idx = np.array(self.obs_df_split.index[pert_mask])
        if len(control_obs) == 0 and len(obs_df_sub_idx) > 0:
            # np.random.choice would fail here with a message that does not
            # say which group is missing.
            raise ValueError(
                f"no control observations labelled 'control | {control_dose} | {cell_id}' "
                f"found for perturbation group '{pert}'"
            )
        # - get the paired control
        pair_control_obs = np.random.choice(control_obs, len(obs_df_sub_idx), replace=True)
        # - set the control barcode
        self.obs_df_split.loc[obs_df_sub_idx, 'control_barcode'] = pair_control_obs

    self.adata_split.obs = self.obs_df_split
    print('='*10, 'set control barcodes finished!')

def get_de_genes(self,
                rankby_abs = True,
                key_added = 'rank_genes_groups'):
    """Fill ``self.adata_split.uns`` with placeholder differential-expression results.

    No statistical test is run: for every label in
    ``self.filter_perturbation_list`` the gene list is simply all variable names
    in their existing order, and every statistic is the constant ``0.1``.

    Args:
        rankby_abs: Accepted for signature compatibility; not used in this body.
        key_added: Key of ``uns`` under which the per-perturbation gene lists
            are stored.

    The statistics are stored under the fixed ``uns`` keys ``'pvals'``,
    ``'pvals_adj'``, ``'scores'`` and ``'logfoldchanges'``, each a dict mapping
    the perturbation label to a list with one entry per variable.
    """
    var_names = list(self.adata_split.var_names)
    n_genes = len(var_names)
    stat_keys = ('pvals', 'pvals_adj', 'scores', 'logfoldchanges')

    gene_dict = {}
    stats = {stat_key: {} for stat_key in stat_keys}
    for pert in tqdm(self.filter_perturbation_list):
        # a fresh list per perturbation, so entries never share storage
        gene_dict[pert] = list(var_names)
        for stat_key in stat_keys:
            stats[stat_key][pert] = [0.1] * n_genes

    uns = self.adata_split.uns
    uns[key_added] = gene_dict
    for stat_key in stat_keys:
        uns[stat_key] = stats[stat_key]
    print('='*10, 'get de genes finished!')

# Bind the set_control_barcode function defined in this cell to this one instance, so that
# pert_data.set_control_barcode() runs it (presumably replacing the class's own
# implementation - see the section title). get_de_genes is not bound; the run
# cell calls it as a plain function with pert_data as its argument.
pert_data.set_control_barcode = MethodType(set_control_barcode, pert_data)
[ ]:

7. Run process

[22]:


# Disabled pipeline steps are left commented out.
# print("Step 2: Filtering perturbations...")
# pert_data.filter_perturbation()

# print("Step 3: Preprocessing adata and selecting HVGs...")
# pert_data.get_and_process_adata(var_num=pert_data.var_num)

print("Step 2: Setting control barcodes...")
pert_data.set_control_barcode()

# print("Step 5: Calculating E-distances...")
# pert_data.get_edis_2()

# print("Step 6: Filtering sgRNAs...")
# pert_data.adata_split.obs['sgRNA_new'] = 'control'
# pert_data.filter_sgRNA()

print("Step 3: Data splitting...")
# pert_data.data_split_2(split_type = 0,
#                        test_perts = None)
# Overrides the split_ratio that was passed to the constructor ([0.8, 0.2, 0]).
pert_data.split_ratio = [0.9, 0.1, 0]
pert_data.data_split_2(split_type = 1, test_perts = None)
# pert_data.data_split_2(split_type = 2,
#                        test_perts = None)

print("Step 4: Getting differential genes...")
# Called as a plain function: get_de_genes is not bound to pert_data.
get_de_genes(pert_data)

print("Step 5: Saving processed data...")
# Save the processed data; the context manager closes the file handle even if
# pickling fails.
with open(os.path.join(data_dir, 'pert_data.pkl'), 'wb') as pkl_file:
    pickle.dump(pert_data, pkl_file)

print("Pert_Data object generation completed!")
print(f"Final dataset shape: {pert_data.adata_split.shape}")
print(f"Number of perturbation groups: {len(pert_data.filter_perturbation_list)}")
Step 2: Setting control barcodes...
100%|██████████| 9405/9405 [00:15<00:00, 600.49it/s]
========== set control barcodes finished!
Step 3: Data splitting...
========== data split finished!
Step 4: Getting differential genes...
100%|██████████| 9405/9405 [00:00<00:00, 11605.10it/s]
========== get de genes finished!
Step 5: Saving processed data...
Pert_Data object generation completed!
Final dataset shape: (9656, 978)
Number of perturbation groups: 9405
[ ]: