{ "cells": [
  {
   "cell_type": "markdown",
   "id": "00a28b50",
   "metadata": {},
   "source": [
    "# Create chemical perturbation data\n",
    "This notebook builds the chemical-perturbation dataset (L1000 phase 1) used for X-Pert training."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "836badb3",
   "metadata": {},
   "source": [
    "## 1. Import the necessary libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "a086ed9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import logging\n",
    "import os\n",
    "import pickle\n",
    "import warnings\n",
    "from types import MethodType\n",
    "\n",
    "import numpy as np\n",
    "import scanpy as sc\n",
    "from tqdm import tqdm\n",
    "\n",
    "import xpert as xp\n",
    "from xpert.data.utils import get_info_txt\n",
    "\n",
    "# Keep the notebook output readable: ignore Python warnings, set scanpy verbosity to 0\n",
    "# and raise the scanpy/anndata logger thresholds to ERROR.\n",
    "warnings.filterwarnings('ignore')\n",
    "sc.settings.verbosity = 0\n",
    "for _logger_name in ('scanpy', 'anndata'):\n",
    "    logging.getLogger(_logger_name).setLevel(logging.ERROR)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "45f2c9e6",
   "metadata": {},
   "source": [
    "## 2. Load the original AnnData file of perturbation data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "81e8b788",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AnnData object with n_obs × n_vars = 678401 × 978\n",
       "    obs: 'inst_id', 'rna_plate', 'rna_well', 'pert_id', 'pert_iname', 'pert_type', 'pert_dose', 'pert_dose_unit', 'pert_time', 'pert_time_unit', 'cell_id', 'canonical_smiles', 'plate_id'\n",
       "    var: 'pr_gene_id', 'pr_gene_symbol', 'pr_gene_title', 'pr_is_lm', 'pr_is_bing'"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# L1000 phase-1 AnnData object (observations x landmark genes)\n",
    "adata_path = '../../data/L1000_phase1/adata_L1000_phase1.h5ad'\n",
    "adata = sc.read(adata_path)\n",
    "adata"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "61d19fe9",
   "metadata": {},
   "source": [
    "Add the `obs` columns `dose`, `perturbation_group`, `perturbation_new` and `celltype_new`, which the preprocessing steps below rely on."
]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "2441a6fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_perturbation_new(x):\n",
    "    \"\"\"Return the perturbation of one obs row; the DMSO vehicle is relabelled 'control'.\"\"\"\n",
    "    return 'control' if x['pert_iname'] == 'DMSO' else x['pert_iname']\n",
    "\n",
    "def get_perturbation_group(x):\n",
    "    \"\"\"Return '<perturbation> | <dose> | <cell_id>' for one obs row (DMSO -> 'control').\"\"\"\n",
    "    return ' | '.join([get_perturbation_new(x), x['dose'], x['cell_id']])\n",
    "\n",
    "def get_cell_type_new(x):\n",
    "    \"\"\"Return the cell line (cell_id) of one obs row.\"\"\"\n",
    "    return x['cell_id']\n",
    "\n",
    "# The helpers above state the per-row rules. The columns below are built with the\n",
    "# equivalent vectorized pandas operations instead of three row-wise\n",
    "# DataFrame.apply(axis=1) passes, which are very slow on ~680k observations.\n",
    "adata.obs['dose'] = adata.obs['pert_dose']\n",
    "\n",
    "pert_iname = adata.obs['pert_iname'].astype(str)\n",
    "perturbation_new = pert_iname.where(pert_iname != 'DMSO', 'control')\n",
    "perturbation_group = (\n",
    "    perturbation_new\n",
    "    + ' | ' + adata.obs['dose'].astype(str)\n",
    "    + ' | ' + adata.obs['cell_id'].astype(str)\n",
    ")\n",
    "celltype_new = adata.obs['cell_id'].astype(str)\n",
    "\n",
    "# assign in the original column order: perturbation_group, perturbation_new, celltype_new\n",
    "adata.obs['perturbation_group'] = perturbation_group\n",
    "adata.obs['perturbation_new'] = perturbation_new\n",
    "adata.obs['celltype_new'] = celltype_new\n",
    "\n",
    "# control pairing later in the notebook needs DMSO control observations\n",
    "assert (perturbation_new == 'control').any(), 'no DMSO control observations found'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8b741f37",
   "metadata": {},
   "source": [
    "## 3. 
Generate the Pert_Data object\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "189334c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# parameters\n",
    "pert_cell_filter = 0         # perturbations with fewer cells than this are filtered out\n",
    "seed = 2024                  # random seed\n",
    "split_type = 1               # 1: split by unseen perturbations; 0: split by unseen cell types\n",
    "split_ratio = [0.8, 0.2, 0]  # train : test : val fractions (no validation set here)\n",
    "var_num = 5000               # number of highly variable genes to select\n",
    "# NOTE(review): adata has only 978 genes, fewer than var_num - confirm how Byte_Pert_Data handles this.\n",
    "num_de_genes = 20            # number of differentially expressed genes\n",
    "bs_train = 32                # batch size of the train loader\n",
    "bs_test = 32                 # batch size of the test loader\n",
    "data_dir = '../../data/L1000_phase1/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "237fa71a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Step 1: Reading files...\n",
      "========== read file finished!\n"
     ]
    }
   ],
   "source": [
    "# Create the Byte_Pert_Data object from the parameters above\n",
    "pert_data = xp.data.Byte_Pert_Data(\n",
    "    prefix='L1000_phase1',\n",
    "    pert_cell_filter=pert_cell_filter,\n",
    "    seed=seed,\n",
    "    split_ratio=split_ratio,\n",
    "    split_type=split_type,\n",
    "    var_num=var_num,\n",
    "    num_de_genes=num_de_genes,\n",
    "    bs_train=bs_train,\n",
    "    bs_test=bs_test,\n",
    ")\n",
    "\n",
    "print(\"Step 1: Reading files...\")\n",
    "pert_data.read_files(adata)\n",
    "\n",
    "# No train/test split here: the split view is the whole dataset.\n",
    "# NOTE(review): adata_split and adata_ori are now the same object, so later in-place\n",
    "# obs edits (e.g. control_barcode) also appear in adata_ori - confirm this is intended.\n",
    "pert_data.adata_split = pert_data.adata_ori\n",
    "\n",
    "# Record every non-control perturbation group (perturbation | dose | cell_id)\n",
    "split_obs = pert_data.adata_split.obs\n",
    "tmp_obs = split_obs[split_obs['perturbation_new'] != 'control']\n",
    "pert_data.filter_perturbation_list = list(tmp_obs['perturbation_group'].unique())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e118f41f",
   "metadata": {},
   "source": [
    "## 4. 
Rewrite the cell-pairing functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "5bb87828",
   "metadata": {},
   "outputs": [],
   "source": [
    "def set_control_barcode(self):\n",
    "    \"\"\"Pair every perturbed observation with a randomly drawn control observation.\n",
    "\n",
    "    Control rows (perturbation_new == 'control') are grouped by the key\n",
    "    'plate_id | celltype_new | pert_time'. For each non-control row, in row order,\n",
    "    one control barcode with the same key is drawn with np.random.choice after\n",
    "    seeding the global NumPy RNG with self.seed. The result is stored in\n",
    "    self.adata_split.obs['control_barcode']; control rows keep 'None'.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: if a perturbed row has no control sharing its plate, cell type\n",
    "            and time point.\n",
    "    \"\"\"\n",
    "    obs_df = self.adata_split.obs\n",
    "    barcodes = obs_df.index.astype(str)\n",
    "    # Read whole columns once instead of millions of scalar .loc lookups, and use the\n",
    "    # real obs_names rather than assuming they are '0'..'n-1'.\n",
    "    keys = [\n",
    "        ' | '.join([plate_id, cell_type, str(pert_time)])\n",
    "        for plate_id, cell_type, pert_time in zip(\n",
    "            obs_df['plate_id'], obs_df['celltype_new'], obs_df['pert_time'])\n",
    "    ]\n",
    "    is_control = (obs_df['perturbation_new'] == 'control').to_numpy()\n",
    "\n",
    "    # - group control barcodes by key\n",
    "    pert_index_dict = {}\n",
    "    for barcode, key, control in tqdm(zip(barcodes, keys, is_control), total=len(keys)):\n",
    "        if control:\n",
    "            pert_index_dict.setdefault(key, []).append(barcode)\n",
    "\n",
    "    # - draw one control per perturbed row; same seed, row order and\n",
    "    #   np.random.choice call as before, so pairings stay reproducible\n",
    "    np.random.seed(self.seed)\n",
    "    control_barcode = ['None'] * len(keys)\n",
    "    for pos, (key, control) in enumerate(tqdm(zip(keys, is_control), total=len(keys))):\n",
    "        if control:\n",
    "            continue\n",
    "        if key not in pert_index_dict:\n",
    "            raise KeyError(f'no control observation for plate | cell type | time = {key!r}')\n",
    "        control_barcode[pos] = np.random.choice(pert_index_dict[key], 1, replace=True)[0]\n",
    "\n",
    "    obs_df['control_barcode'] = control_barcode\n",
    "    self.adata_split.obs = obs_df\n",
    "\n",
    "# NOTE(review): only set_control_barcode is bound to pert_data with MethodType in this\n",
    "# cell; confirm where get_de_genes is attached before it is used.\n",
    "def get_de_genes(self,\n",
    "                 rankby_abs = True,\n",
    "                 key_added = 'rank_genes_groups'):\n",
    "    \"\"\"Store constant dummy differential-expression results for each perturbation group.\n",
    "\n",
    "    No statistical test is run: for every entry of self.filter_perturbation_list, all\n",
    "    genes of self.adata_split are listed and the p-values, adjusted p-values, scores\n",
    "    and log fold-changes are all set to 0.1. Results go to self.adata_split.uns under\n",
    "    key_added, 'pvals', 'pvals_adj', 'scores' and 'logfoldchanges'.\n",
    "\n",
    "    NOTE(review): rankby_abs is accepted but not used.\n",
    "    \"\"\"\n",
    "    gene_dict = {}\n",
    "    pvals_dict, pvals_adj_dict, scores_dict, logfoldchanges_dict = {}, {}, {}, {}\n",
    "\n",
    "    for pert in tqdm(self.filter_perturbation_list):\n",
    "        gene_dict[pert] = list(self.adata_split.var_names)\n",
    "        pvals_dict[pert] = [0.1]*len(self.adata_split.var_names)\n",
    "        pvals_adj_dict[pert] = 
[0.1]*len(self.adata_split.var_names)\n", " scores_dict[pert] = [0.1]*len(self.adata_split.var_names)\n", " logfoldchanges_dict[pert] = [0.1]*len(self.adata_split.var_names)\n", " \n", " self.adata_split.uns[key_added] = gene_dict\n", " self.adata_split.uns['pvals'] = pvals_dict\n", " self.adata_split.uns['pvals_adj'] = pvals_adj_dict\n", " self.adata_split.uns['scores'] = scores_dict\n", " self.adata_split.uns['logfoldchanges'] = logfoldchanges_dict\n", " print('='*10,f'get de genes finished!')\n", "\n", "pert_data.set_control_barcode = MethodType(set_control_barcode, pert_data)" ] }, { "cell_type": "code", "execution_count": null, "id": "14de6610", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "a86b80c5", "metadata": {}, "source": [ "## 5. Run process" ] }, { "cell_type": "code", "execution_count": 17, "id": "180b2dff", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Step 2: Setting control barcodes...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " 0%| | 0/678401 [00:00