{ "cells": [ { "cell_type": "markdown", "id": "00a28b50", "metadata": {}, "source": [ "# Create cotrain perturabation data" ] }, { "cell_type": "markdown", "id": "836badb3", "metadata": {}, "source": [ "## 1. import the necessary libraries" ] }, { "cell_type": "code", "execution_count": 1, "id": "a086ed9e", "metadata": {}, "outputs": [], "source": [ "import xpert as xp\n", "from xpert.data.utils import get_info_txt\n", "import scanpy as sc\n", "import numpy as np\n", "import pandas as pd\n", "from types import MethodType\n", "from tqdm import tqdm\n", "import os\n", "import pickle\n", "import warnings\n", "import logging\n", "warnings.filterwarnings('ignore')\n", "sc.settings.verbosity = 0\n", "logging.getLogger('scanpy').setLevel(logging.ERROR)\n", "logging.getLogger('anndata').setLevel(logging.ERROR)\n" ] }, { "cell_type": "markdown", "id": "45f2c9e6", "metadata": {}, "source": [ "## 2. load adata and get specific cell linesm" ] }, { "cell_type": "code", "execution_count": 2, "id": "0206c76f", "metadata": {}, "outputs": [], "source": [ "cell_line = 'HT29'" ] }, { "cell_type": "code", "execution_count": 3, "id": "81e8b788", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "AnnData object with n_obs × n_vars = 473647 × 978\n", " obs: 'sig_id', 'pert_id', 'pert_iname', 'pert_type', 'cell_id', 'pert_dose', 'pert_dose_unit', 'pert_idose', 'pert_time', 'pert_time_unit', 'pert_itime', 'distil_id'\n", " var: 'pr_gene_id', 'pr_gene_symbol', 'pr_gene_title', 'pr_is_lm', 'pr_is_bing'" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "adata_sig = sc.read('../../data/L1000_phase1_cotrain/adata_sig.h5ad')\n", "adata_sig" ] }, { "cell_type": "code", "execution_count": 4, "id": "a7678beb", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "AnnData object with n_obs × n_vars = 17815 × 978\n", " obs: 'sig_id', 'pert_id', 'pert_iname', 'pert_type', 'cell_id', 'pert_dose', 'pert_dose_unit', 'pert_idose', 'pert_time', 
'pert_time_unit', 'pert_itime', 'distil_id'\n", " var: 'pr_gene_id', 'pr_gene_symbol', 'pr_gene_title', 'pr_is_lm', 'pr_is_bing'" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ],
 "source": [
  "# Keep gene (trt_sh.cgs) and compound (trt_cp) perturbations of the selected cell line.\n",
  "is_perturbation = adata_sig.obs['pert_type'].isin(['trt_sh.cgs', 'trt_cp'])\n",
  "is_cell_line = adata_sig.obs['cell_id'] == cell_line\n",
  "adata_pert = adata_sig[is_perturbation & is_cell_line].copy()\n",
  "adata_pert" ] },
{ "cell_type": "code", "execution_count": 5, "id": "46c68452", "metadata": {},
 "outputs": [ { "data": { "text/plain": [ "trt_cp 14513\n", "trt_sh.cgs 3302\n", "Name: pert_type, dtype: int64" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ],
 "source": [ "adata_pert.obs['pert_type'].value_counts()" ] },
{ "cell_type": "markdown", "id": "9634d7d8", "metadata": {}, "source": [ "add smiles" ] },
{ "cell_type": "code", "execution_count": 6, "id": "89baa445", "metadata": {},
 "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "# of unique perturbations: 14215\n" ] } ],
 "source": [
  "pert_id_unique = pd.Series(np.unique(adata_pert.obs['pert_id']))\n",
  "print(f\"# of unique perturbations: {len(pert_id_unique)}\")" ] },
{ "cell_type": "code", "execution_count": 7, "id": "b601aa39", "metadata": {}, "outputs": [],
 "source": [
  "import pathlib\n",
  "# Perturbation annotation table; only pert_id and canonical_smiles are used below.\n",
  "reference_df = pd.read_csv('../../data/L1000_phase1_cotrain/GSE92742_Broad_LINCS_pert_info.txt.gz', delimiter = \"\\t\")" ] },
{ "cell_type": "code", "execution_count": 8, "id": "368e43c5", "metadata": {},
 "outputs": [ { "data": { "text/plain": [
  "-666 3338\n",
  "restricted 14\n",
  "CS(=O)(=O)CCNCc1ccc(o1)-c1ccc2ncnc(Nc3ccc(OCc4cccc(F)c4)c(Cl)c3)c2c1 2\n",
  "CCOC(=O)C1=C(NC(=C(C1C2=CC=CC=C2Cl)C(=O)OC)C)COCCN 2\n",
  "CO[C@H]1\\C=C\\O[C@@]2(C)Oc3c(C2=O)c2c(O)c(\\C=N\\N4CCN(C)CC4)c(NC(=O)\\C(C)=C/C=C/[C@H](C)[C@H](O)[C@@H](C)[C@@H](O)[C@@H](C)[C@H](OC(C)=O)[C@@H]1C)c(O)c2c(O)c3C 2\n",
  " ... \n",
  "COc1ccc2c3c([C@@H](CO)N(C[C@]33CCN(CC3)C(=O)C3CCOCC3)C(C)=O)n(C)c2c1 1\n",
  "CN(C)S(=O)(=O)c1ccc(N2CCCC2)c(c1)C(=O)N1CCN(CC1)c1ccccc1 1\n",
  "NC(=O)c1sc(cc1N)-c1ccsc1 1\n",
  "CC(=CCCN1CCC(CC1)NC(=O)[C@@](O)(C2CCCC2)c3ccccc3)C 1\n",
  "OC[C@@H]1[C@H]2Cn3c(ccc(C4=CCCCC4)c3=O)[C@@H]([C@H]1C(O)=O)N2C(=O)Cc1ccccn1 1\n",
  "Name: canonical_smiles, Length: 10858, dtype: int64" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ],
 "source": [
  "# Restrict the annotation table to the perturbations present in adata_pert.\n",
  "in_adata = reference_df['pert_id'].isin(pert_id_unique)\n",
  "reference_df = reference_df.loc[in_adata, ['pert_id', 'canonical_smiles']]\n",
  "reference_df['canonical_smiles'].value_counts()" ] },
{ "cell_type": "code", "execution_count": 9, "id": "65e83cb9", "metadata": {}, "outputs": [],
 "source": [
  "# validate='many_to_one' makes pandas raise if the annotation table lists a pert_id more than\n",
  "# once; without it the left join would silently duplicate observations and obs would no longer\n",
  "# line up with X.\n",
  "adata_pert.obs = pd.merge(adata_pert.obs, reference_df, on='pert_id', how='left', validate='many_to_one')\n",
  "# The merged frame carries a fresh integer index, so cast the resulting obs_names to str.\n",
  "adata_pert.obs_names = adata_pert.obs_names.astype(str)" ] },
{ "cell_type": "code", "execution_count": 10, "id": "842a85c6", "metadata": {},
 "outputs": [ { "data": { "text/plain": [
  "(AnnData object with n_obs × n_vars = 14513 × 978\n",
  " obs: 'sig_id', 'pert_id', 'pert_iname', 'pert_type', 'cell_id', 'pert_dose', 'pert_dose_unit', 'pert_idose', 'pert_time', 'pert_time_unit', 'pert_itime', 'distil_id', 'canonical_smiles'\n",
  " var: 'pr_gene_id', 'pr_gene_symbol', 'pr_gene_title', 'pr_is_lm', 'pr_is_bing',\n",
  " AnnData object with n_obs × n_vars = 3302 × 978\n",
  " obs: 'sig_id', 'pert_id', 'pert_iname', 'pert_type', 'cell_id', 'pert_dose', 'pert_dose_unit', 'pert_idose', 'pert_time', 'pert_time_unit', 'pert_itime', 'distil_id', 'canonical_smiles'\n",
  " var: 'pr_gene_id', 'pr_gene_symbol', 'pr_gene_title', 'pr_is_lm', 'pr_is_bing')" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ],
 "source": [
  "# Split by perturbation type: compounds need SMILES handling, gene perturbations do not.\n",
  "adata_drug = adata_pert[adata_pert.obs['pert_type'] == 'trt_cp'].copy()\n",
  "adata_gene = adata_pert[adata_pert.obs['pert_type'] == 'trt_sh.cgs'].copy()\n",
  "adata_drug, adata_gene" ] },
{ "cell_type": "code", "execution_count": 11, "id":
"6bbc7102", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Among 14513 observations, 0.76% (110) have an invalid SMILES string\n" ] }, { "data": { "text/plain": [ "AnnData object with n_obs × n_vars = 14403 × 978\n", " obs: 'sig_id', 'pert_id', 'pert_iname', 'pert_type', 'cell_id', 'pert_dose', 'pert_dose_unit', 'pert_idose', 'pert_time', 'pert_time_unit', 'pert_itime', 'distil_id', 'canonical_smiles'\n", " var: 'pr_gene_id', 'pr_gene_symbol', 'pr_gene_title', 'pr_is_lm', 'pr_is_bing'" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# remove invalid smiles\n", "adata_drug.obs.loc[:, 'canonical_smiles'] = adata_drug.obs.canonical_smiles.astype('str')\n", "invalid_smiles = adata_drug.obs.canonical_smiles.isin(['-666', \n", " 'restricted', \n", " 'nan'\n", " ])\n", "# cond = adata_drug.obs['pert_type']=='trt_sh.cgs'\n", "print(f'Among {len(adata_drug)} observations, {100*invalid_smiles.sum()/len(adata_drug):.2f}% ({invalid_smiles.sum()}) have an invalid SMILES string')\n", "adata_drug = adata_drug[(~invalid_smiles)].copy()\n", "adata_drug" ] }, { "cell_type": "code", "execution_count": 12, "id": "5a493eb7", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "AnnData object with n_obs × n_vars = 6353 × 978\n", " obs: 'sig_id', 'pert_id', 'pert_iname', 'pert_type', 'cell_id', 'pert_dose', 'pert_dose_unit', 'pert_idose', 'pert_time', 'pert_time_unit', 'pert_itime', 'distil_id', 'canonical_smiles'\n", " var: 'pr_gene_id', 'pr_gene_symbol', 'pr_gene_title', 'pr_is_lm', 'pr_is_bing'" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# - contruct pert_embed to delete same embeddings\n", "pert_embed_dict = {}\n", "pert_embed = pd.read_csv('../../data/L1000_phase1_cotrain/embed_ecfp.csv', sep = \",\", index_col=0)\n", "\n", "for pert in pert_embed.columns:\n", " if pert in pert_embed.columns:\n", " pert_embed_dict[pert] = pert_embed.loc[:, 
pert].values\n", " else:\n", " print(f'{pert} not in pert_embed')\n", " pert_embed_dict[pert] = pert_embed.loc[:, np.random.choice(pert_embed.columns, 1)[0]].values\n", "\n", "\n", "# - create drug to embeddings\n", "embed_drugs = np.unique(adata_drug.obs['pert_iname'])\n", "drug_embedding_dict = {}\n", "for i, drug in enumerate(embed_drugs):\n", " drug_embedding_dict[drug] = pert_embed_dict[drug]\n", "\n", "\n", "from collections import defaultdict\n", "# 将相同的向量聚在一起\n", "embedding_to_drugs = defaultdict(list)\n", "\n", "for drug, emb in drug_embedding_dict.items():\n", " emb_key = tuple(emb.tolist()) # 把 ndarray 转成 hashable 的 tuple\n", " embedding_to_drugs[emb_key].append(drug)\n", "\n", "len(embedding_to_drugs)\n", "\n", "unique_drugs = [value[0] for key, value in embedding_to_drugs.items()]\n", "len(unique_drugs)\n", "adata_drug = adata_drug[adata_drug.obs['pert_iname'].isin(unique_drugs)].copy()\n", "adata_drug\n", "# remove dulplicated smiles\n", "dup_mask = adata_drug.obs['canonical_smiles'].duplicated(keep='first') # True == 要删\n", "adata_drug = adata_drug[~dup_mask, :].copy()\n", "adata_drug" ] }, { "cell_type": "markdown", "id": "27a47778", "metadata": {}, "source": [ "## 3. 
Get control cells" ] },
{ "cell_type": "code", "execution_count": 13, "id": "33ed202c", "metadata": {},
 "outputs": [ { "data": { "text/plain": [
  "AnnData object with n_obs × n_vars = 1319138 × 978\n",
  " obs: 'inst_id', 'rna_plate', 'rna_well', 'pert_id', 'pert_iname', 'pert_type', 'pert_dose', 'pert_dose_unit', 'pert_time', 'pert_time_unit', 'cell_id'\n",
  " var: 'pr_gene_id', 'pr_gene_symbol', 'pr_gene_title', 'pr_is_lm', 'pr_is_bing'" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ],
 "source": [
  "adata_level3 = sc.read('../../data/L1000_phase1_cotrain/adata_inst.h5ad')\n",
  "adata_level3" ] },
{ "cell_type": "code", "execution_count": 14, "id": "45098999", "metadata": {},
 "outputs": [ { "data": { "text/plain": [
  "View of AnnData object with n_obs × n_vars = 1917 × 978\n",
  " obs: 'inst_id', 'rna_plate', 'rna_well', 'pert_id', 'pert_iname', 'pert_type', 'pert_dose', 'pert_dose_unit', 'pert_time', 'pert_time_unit', 'cell_id'\n",
  " var: 'pr_gene_id', 'pr_gene_symbol', 'pr_gene_title', 'pr_is_lm', 'pr_is_bing'" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ],
 "source": [
  "# Untreated profiles of the selected cell line.\n",
  "cond_1 = adata_level3.obs['cell_id'] == cell_line\n",
  "cond_2 = adata_level3.obs['pert_type'] == 'ctl_untrt'\n",
  "adata_level3_part = adata_level3[cond_1 & cond_2]\n",
  "adata_level3_part" ] },
{ "cell_type": "code", "execution_count": 15, "id": "cee5551f", "metadata": {},
 "outputs": [ { "data": { "text/plain": [
  "AnnData object with n_obs × n_vars = 1 × 978\n",
  " obs: 'inst_id', 'rna_plate', 'rna_well', 'pert_id', 'pert_iname', 'pert_type', 'pert_dose', 'pert_dose_unit', 'pert_time', 'pert_time_unit', 'cell_id'" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ],
 "source": [
  "import anndata as ad\n",
  "\n",
  "# Collapse the untreated profiles into one mean profile; its obs row is copied from the first profile.\n",
  "ctrl_mean = np.mean(adata_level3_part.X, axis=0).reshape(1, -1)\n",
  "ctrl_obs = pd.DataFrame(adata_level3_part.obs.iloc[0, :]).T\n",
  "ctrl_var = pd.DataFrame(index=adata_level3_part.var_names)\n",
  "adata_level3_ctrl = ad.AnnData(X=ctrl_mean, obs=ctrl_obs, var=ctrl_var)\n",
  "adata_level3_ctrl" ] },
{ "cell_type": "code", "execution_count": 16, "id": "055dc986", "metadata": {},
 "outputs": [ { "data": { "text/plain": [
  "AnnData object with n_obs × n_vars = 9656 × 978\n",
  " obs: 'sig_id', 'pert_id', 'pert_iname', 'pert_type', 'cell_id', 'pert_dose', 'pert_dose_unit', 'pert_idose', 'pert_time', 'pert_time_unit', 'pert_itime', 'distil_id', 'canonical_smiles', 'inst_id', 'rna_plate', 'rna_well'" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ],
 "source": [
  "# Outer join keeps the obs columns of all three inputs.\n",
  "adata_concat = ad.concat([adata_drug, adata_gene, adata_level3_ctrl], join='outer')\n",
  "adata_concat" ] },
{ "cell_type": "markdown", "id": "3a368827", "metadata": {}, "source": [ "## 4. Add basic attributes" ] },
{ "cell_type": "code", "execution_count": 17, "id": "3fe6389c", "metadata": {}, "outputs": [],
 "source": [ "adata = adata_concat.copy()" ] },
{ "cell_type": "code", "execution_count": 18, "id": "2441a6fa", "metadata": {}, "outputs": [],
 "source": [
  "def get_perturbation_new(x):\n",
  "    \"\"\"Return 'control' for an untreated obs row and the perturbation name otherwise.\"\"\"\n",
  "    return 'control' if x['pert_iname'] == 'UnTrt' else x['pert_iname']\n",
  "\n",
  "def get_perturbation_group(x):\n",
  "    \"\"\"Build the '<perturbation> | <dose> | <cell_id>' label of one obs row.\"\"\"\n",
  "    name = 'control' if x['pert_iname'] == 'UnTrt' else x['pert_iname']\n",
  "    return ' | '.join([name, x['dose'], x['cell_id']])\n",
  "\n",
  "def get_cell_type_new(x):\n",
  "    \"\"\"Return the cell_id of one obs row.\"\"\"\n",
  "    return x['cell_id']\n",
  "\n",
  "adata.obs['dose'] = adata.obs['pert_dose']\n",
  "adata.obs['perturbation_group'] = adata.obs.apply(get_perturbation_group, axis=1)\n",
  "adata.obs['perturbation_new'] = adata.obs.apply(get_perturbation_new, axis=1)\n",
  "adata.obs['celltype_new'] = adata.obs.apply(get_cell_type_new, axis=1)\n" ] },
{ "cell_type": "code", "execution_count": null, "id": "d127e4f7", "metadata": {}, "outputs": [], "source": [] },
{ "cell_type": "code", "execution_count": null, "id": "4be019de", "metadata": {}, "outputs": [], "source": [] },
{ "cell_type": "markdown", "id": "8b741f37", "metadata": {}, "source": [ "## 5. generate Pert_Data object\n", "\n" ] },
{ "cell_type": "code", "execution_count": 19, "id": "189334c2", "metadata": {}, "outputs": [],
 "source": [
  "# parameters\n",
  "\n",
  "pert_cell_filter = 0 # this is used to filter perts, cell number less than this will be filtered\n",
  "seed = 2024 # this is the random seed\n",
  "split_type = 1 # 1 for unseen perts; 0 for unseen celltypes\n",
  "split_ratio = [0.8, 0.2, 0] # train:test:val; val is used to choose data, test is for final validation\n",
  "var_num = 5000 # selecting hvg number\n",
  "num_de_genes = 20 # number of de genes\n",
  "bs_train = 32 # batch size of trainloader\n",
  "bs_test = 32 # batch size of testloader\n",
  "data_dir = '../../data/L1000_phase1_cotrain/'" ] },
{ "cell_type": "code", "execution_count": 20, "id": "237fa71a", "metadata": {},
 "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Step 1: Reading files...\n", "========== read file finished!\n" ] } ],
 "source": [
  "# Create Pert_Data object\n",
  "pert_data = xp.data.Byte_Pert_Data(\n",
  "    prefix='L1000_phase1',\n",
  "    pert_cell_filter=pert_cell_filter,\n",
  "    seed=seed,\n",
  "    split_ratio=split_ratio,\n",
  "    split_type=split_type,\n",
  "    var_num=var_num,\n",
  "    num_de_genes=num_de_genes,\n",
  "    bs_train=bs_train,\n",
  "    bs_test=bs_test\n",
  ")\n",
  "\n",
  "# Complete data processing pipeline\n",
  "print(\"Step 1: Reading files...\")\n",
  "pert_data.read_files(adata)\n",
  "\n",
  "# Get the filter_perturbation_list: every non-control perturbation group, in order of first appearance\n",
  "pert_data.adata_split = pert_data.adata_ori\n",
  "is_perturbed = pert_data.adata_split.obs['perturbation_new'] != 'control'\n",
  "tmp_obs = pert_data.adata_split[is_perturbed].obs\n",
  "pert_data.filter_perturbation_list = list(tmp_obs['perturbation_group'].unique()) # record the perturbation pair\n",
  "\n" ] },
{ "cell_type": "markdown", "id": "e118f41f", "metadata": {}, "source": [ "## 6.
Rewrite cell pair function " ] },
{ "cell_type": "code", "execution_count": 21, "id": "5bb87828", "metadata": {}, "outputs": [],
 "source": [
  "def set_control_barcode(self, control_dose='-666.0'):\n",
  "    \"\"\"\n",
  "    Pair every perturbed observation with a randomly drawn control observation.\n",
  "\n",
  "    Each entry of ``self.filter_perturbation_list`` is a label of the form\n",
  "    '<perturbation> | <dose> | <cell_id>'. Its controls are the rows whose\n",
  "    ``perturbation_group`` equals 'control | <control_dose> | <cell_id>'. One control\n",
  "    index label is sampled with replacement per perturbed row and written to\n",
  "    ``obs['control_barcode']``; rows that are never paired keep the string 'None'.\n",
  "    Sampling is seeded with ``self.seed``.\n",
  "\n",
  "    Parameters\n",
  "    ----------\n",
  "    control_dose : str\n",
  "        Dose token used in the control group label (default '-666.0').\n",
  "\n",
  "    Raises\n",
  "    ------\n",
  "    ValueError\n",
  "        If a perturbation has observations but its control group has none.\n",
  "    \"\"\"\n",
  "    self.obs_df_split = self.adata_split.obs.copy()\n",
  "\n",
  "    # - set all control_barcode to None\n",
  "    self.obs_df_split['control_barcode'] = 'None'\n",
  "\n",
  "    group_labels = self.obs_df_split['perturbation_group']\n",
  "    # control group label -> index labels of its rows; avoids rescanning obs for every perturbation\n",
  "    control_cache = {}\n",
  "\n",
  "    np.random.seed(self.seed)\n",
  "    for pert in tqdm(self.filter_perturbation_list):\n",
  "        # - get all the control barcodes of this perturbation's cell line\n",
  "        control_group = ' | '.join(['control', control_dose, pert.split(' | ')[-1]])\n",
  "        if control_group not in control_cache:\n",
  "            control_cache[control_group] = np.array(self.obs_df_split[group_labels == control_group].index)\n",
  "        control_obs = control_cache[control_group]\n",
  "\n",
  "        obs_df_sub_idx = np.array(self.obs_df_split[group_labels == pert].index)\n",
  "        if len(obs_df_sub_idx) == 0:\n",
  "            # nothing to pair for this label\n",
  "            continue\n",
  "        if len(control_obs) == 0:\n",
  "            # np.random.choice would fail on an empty pool; say which control group is missing\n",
  "            raise ValueError(f'no control observations found for {control_group!r} (needed by {pert!r})')\n",
  "        # - get the paired control\n",
  "        pair_control_obs = np.random.choice(control_obs, len(obs_df_sub_idx), replace=True)\n",
  "        # - set the control barcode\n",
  "        self.obs_df_split.loc[obs_df_sub_idx, 'control_barcode'] = pair_control_obs\n",
  "\n",
  "    self.adata_split.obs = self.obs_df_split\n",
  "    print('='*10, 'set control barcodes finished!')\n",
  "\n",
  "def get_de_genes(self,\n",
  "                 rankby_abs = True,\n",
  "                 key_added = 'rank_genes_groups'):\n",
  "    \"\"\"\n",
  "    Fill ``adata_split.uns`` with placeholder differential-expression results.\n",
  "\n",
  "    No statistical test is run here: every perturbation group receives the full\n",
  "    ``var_names`` list and a constant 0.1 for each gene under the 'pvals',\n",
  "    'pvals_adj', 'scores' and 'logfoldchanges' keys.\n",
  "\n",
  "    Parameters\n",
  "    ----------\n",
  "    rankby_abs : bool\n",
  "        Not used by this implementation.\n",
  "    key_added : str\n",
  "        ``uns`` key under which the per-perturbation gene lists are stored.\n",
  "    \"\"\"\n",
  "    var_names = list(self.adata_split.var_names)\n",
  "    n_vars = len(var_names)\n",
  "    stat_keys = ('pvals', 'pvals_adj', 'scores', 'logfoldchanges')\n",
  "\n",
  "    gene_dict = {}\n",
  "    stat_dicts = {key: {} for key in stat_keys}\n",
  "    for pert in tqdm(self.filter_perturbation_list):\n",
  "        gene_dict[pert] = list(var_names)  # independent list per perturbation\n",
  "        for key in stat_keys:\n",
  "            stat_dicts[key][pert] = [0.1] * n_vars\n",
  "\n",
  "    self.adata_split.uns[key_added] = gene_dict\n",
  "    for key in stat_keys:\n",
  "        self.adata_split.uns[key] = stat_dicts[key]\n",
  "    print('='*10, 'get de genes finished!')\n",
  "\n",
  "# Bind the replacement onto this instance only; get_de_genes is called as a plain function later.\n",
  "pert_data.set_control_barcode = MethodType(set_control_barcode, pert_data)" ] },
{ "cell_type": "code", "execution_count": null, "id": "14de6610", "metadata": {}, "outputs": [], "source": [] },
{ "cell_type": "markdown", "id": "a86b80c5", "metadata": {}, "source": [ "## 7. Run process" ] },
{ "cell_type": "code", "execution_count": 22, "id": "180b2dff", "metadata": {},
 "outputs": [
  { "name": "stdout", "output_type": "stream", "text": [ "Step 2: Setting control barcodes...\n" ] },
  { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 9405/9405 [00:15<00:00, 600.49it/s]\n" ] },
  { "name": "stdout", "output_type": "stream", "text": [ "========== set control barcodes finished!\n", "Step 3: Data splitting...\n", "========== data split finished!\n", "Step 4: Getting differential genes...\n" ] },
  { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 9405/9405 [00:00<00:00, 11605.10it/s]\n" ] },
  { "name": "stdout", "output_type": "stream", "text": [ "========== get de genes finished!\n", "Step 5: Saving processed data...\n", "Pert_Data object generation completed!\n", "Final dataset shape: (9656, 978)\n", "Number of perturbation groups: 9405\n" ] } ],
 "source": [
  "# The filter_perturbation, get_and_process_adata, get_edis_2 and filter_sgRNA steps are not run here.\n",
  "\n",
  "print(\"Step 2: Setting control barcodes...\")\n",
  "pert_data.set_control_barcode()\n",
  "\n",
  "print(\"Step 3: Data splitting...\")\n",
  "# NOTE: this replaces the split_ratio that was passed to the constructor.\n",
  "pert_data.split_ratio = [0.9, 0.1, 0]\n",
  "pert_data.data_split_2(split_type = 1,\n",
  "                       test_perts = None)\n",
  "\n",
  "print(\"Step 4: Getting differential genes...\")\n",
  "get_de_genes(pert_data)\n",
  "\n",
  "print(\"Step 5: Saving processed data...\")\n",
  "# Save the processed data; the context manager closes the file handle once the dump is done.\n",
  "with open(os.path.join(data_dir, 'pert_data.pkl'), 'wb') as f:\n",
  "    pickle.dump(pert_data, f)\n",
  "\n",
  "print(\"Pert_Data object generation completed!\")\n",
  "print(f\"Final dataset shape: {pert_data.adata_split.shape}\")\n",
  "print(f\"Number of perturbation groups: {len(pert_data.filter_perturbation_list)}\")" ] },
{ "cell_type": "code", "execution_count": null, "id": "a1010889", "metadata": {}, "outputs": [], "source": [] } ],
 "metadata": { "kernelspec": { "display_name": "scGPT_2", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" } }, "nbformat": 4, "nbformat_minor": 5 }