From cd997185ba959a44746a79068abea576efdafeb1 Mon Sep 17 00:00:00 2001 From: Mi Yang <11358894+Sophon-0@users.noreply.github.com> Date: Thu, 17 Oct 2024 13:53:20 -0400 Subject: [PATCH] prediction of TCR recognition in percentile rank for a given list of epitopes --- ...diction_NewEpitope_NewTCR_PREDICTION.ipynb | 3164 +++++++++++++++++ 1 file changed, 3164 insertions(+) create mode 100644 notebooks/TCR_epitope_prediction_NewEpitope_NewTCR_PREDICTION.ipynb diff --git a/notebooks/TCR_epitope_prediction_NewEpitope_NewTCR_PREDICTION.ipynb b/notebooks/TCR_epitope_prediction_NewEpitope_NewTCR_PREDICTION.ipynb new file mode 100644 index 0000000..a72a32c --- /dev/null +++ b/notebooks/TCR_epitope_prediction_NewEpitope_NewTCR_PREDICTION.ipynb @@ -0,0 +1,3164 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "id": "-NSHxGYaV7vc" + }, + "outputs": [], + "source": [ + "%%capture\n", + "!pip install esm\n", + "\n", + "import os\n", + "from google.colab import drive\n", + "drive.mount('/content/drive')\n", + "path = \"/content/drive/MyDrive/COLAB/TCR_projects\"\n", + "os.chdir(path)" + ] + }, + { + "cell_type": "code", + "source": [ + "from esm.sdk import client\n", + "from getpass import getpass\n", + "\n", + "model_name = \"esm3-small-2024-08\"\n", + "\n", + "# EvolutionaryScale Forge API token: never commit it to the notebook; paste it at the getpass prompt below.\n", + "token = getpass(\"Token from Forge console: \")\n", + "ESM3_model = client(\n", + " model=model_name, # https://forge.evolutionaryscale.ai/console\n", + " url=\"https://forge.evolutionaryscale.ai\",\n", + " token=token,\n", + ")\n" + ], + "metadata": { + "id": "s3ckojV6roA2", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "7f568e22-9db0-494f-e7a8-b79ccbc7793e" + }, + "execution_count": 2, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.10/dist-packages/google/cloud/storage/transfer_manager.py:30: UserWarning: The module `transfer_manager` 
is a preview feature. Functionality and API may change. This warning will be removed in a future release.\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Token from Forge console: ··········\n" + ] + } + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "id": "be5dH_JwZLlu" + }, + "outputs": [], + "source": [ + "import random\n", + "import pandas as pd\n", + "import numpy as np\n", + "\n", + "from sklearn.model_selection import GridSearchCV\n", + "from sklearn.linear_model import LogisticRegression\n", + "from sklearn.ensemble import RandomForestClassifier\n", + "from sklearn.neural_network import MLPClassifier\n", + "from sklearn.preprocessing import StandardScaler\n", + "\n", + "def preprocess_features(feat, res, train_indices, test_indices):\n", + " x_train = feat.iloc[train_indices, :]\n", + " y_train = res[train_indices]\n", + " x_test = feat.iloc[test_indices, :]\n", + " y_test = res[test_indices]\n", + " # scale the data\n", + " scaler = StandardScaler().fit(x_train)\n", + " x_train = pd.DataFrame(scaler.transform(x_train), index=x_train.index, columns=x_train.columns)\n", + " x_test = pd.DataFrame(scaler.transform(x_test), index=x_test.index, columns=x_test.columns)\n", + " return x_train, y_train, x_test, y_test\n", + "\n", + "\n", + "def run_prediction(mat_test_tab,epitope_embeddings,tcr_embeddings):\n", + "\n", + " def get_embeddings(row):\n", + " epitope = epitope_embeddings.loc[row['epitope']].values\n", + " tcr = tcr_embeddings.loc[row[tcr_features]].values\n", + " return np.concatenate((epitope, tcr))\n", + "\n", + " ################# test set features\n", + " features_test = mat_test_tab.apply(get_embeddings, axis=1)\n", + " features_test = pd.DataFrame(features_test.tolist(), index=features_test.index)\n", + " features_test.index = mat_test_tab[\"epitope\"] + \"_\" + mat_test_tab[tcr_features]\n", + " features_test.columns = epitope_embeddings.columns.tolist() + 
tcr_embeddings.columns.tolist()\n", + " ## add other information\n", + " df_encoded_TCR_subset = df_encoded_TCR.loc[mat_test_tab[tcr_features], : ]\n", + " df_encoded_epitope_subset = df_encoded_epitope.loc[mat_test_tab[\"epitope\"], : ]\n", + "\n", + " ## combine\n", + " # \"ESM3 + VJ genes\" \"all features\" \"ESMonly\" \"withoutESM\"\n", + " if features_name == \"ESM3 + VJ genes\":\n", + " features_test_all = pd.concat([features_test.reset_index(drop=True), df_encoded_TCR_subset.reset_index(drop=True)], axis=1)\n", + " if features_name == \"all features\":\n", + " features_test_all = pd.concat([features_test.reset_index(drop=True), df_encoded_TCR_subset.reset_index(drop=True), df_encoded_epitope_subset.reset_index(drop=True)], axis=1)\n", + " if features_name == \"ESMonly\":\n", + " features_test_all = features_test\n", + " if features_name == \"withoutESM\":\n", + " features_test_all = pd.concat([df_encoded_TCR_subset.reset_index(drop=True), df_encoded_epitope_subset.reset_index(drop=True)], axis=1)\n", + "\n", + " features_test_all.index = features_test.index\n", + "\n", + " ############################################ run ML ############################################\n", + " X_test = features_test_all\n", + " X_test.columns = X_test.columns.astype(str)\n", + "\n", + " x_test_proba = model.predict_proba(X_test)[:, 1]\n", + "\n", + " p_test = pd.DataFrame(\n", + " { 'split': \"test\",\n", + " 'epitope': mat_test_tab.epitope,\n", + " 'sample': X_test.index,\n", + " 'predicted_prob': x_test_proba\n", + " }\n", + " )\n", + "\n", + " return p_test\n", + "\n", + "\n", + "\n", + "def run_ESM3(seq,model):\n", + " try:\n", + " from esm.models.esm3 import ESM3\n", + " from esm.sdk.api import ESMProtein, SamplingConfig\n", + " from esm.utils.constants.models import ESM3_OPEN_SMALL\n", + "\n", + " # Create an ESMProtein object\n", + " protein = ESMProtein(sequence=seq)\n", + "\n", + " # Encode the protein\n", + " protein_tensor = model.encode(protein)\n", + "\n", + " # 
Get the embeddings\n", + " output = model.forward_and_sample(\n", + " protein_tensor,\n", + " SamplingConfig(return_per_residue_embeddings=True)\n", + " )\n", + "\n", + " # aggregate the per residue embedding\n", + " df = pd.DataFrame(output.per_residue_embedding)\n", + " column_means = df.mean(axis=0)\n", + " return column_means.transpose(), seq\n", + "\n", + " except Exception as e:\n", + " print(f\"Error processing sequence: {seq}\") # Print the problematic sequence\n", + " print(e) # Print the exception details\n", + " return None # Or handle the error differently\n", + "\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "id": "eB-K1R38U_Y3" + }, + "outputs": [], + "source": [ + "######################################## TCR-Epitope Binding Affinity Prediction Task #################################\n", + "os.chdir(path)\n", + "combined_df = pd.read_csv(\"MixTCRpred/full_training_set_146pmhc.csv\")\n", + "\n", + "# combine cdr3\n", + "combined_df[\"cdr3\"] = combined_df[\"cdr3_TRA\"] + combined_df[\"cdr3_TRB\"]\n", + "combined_df[\"value\"] = 1\n", + "combined_df.index = combined_df[\"epitope\"] + \"_\" + combined_df[\"cdr3\"]\n", + "combined_df\n", + "\n", + "##################################################### choose ESM model #################################################\n", + "\n", + "# \"esm3-small-2024-08\" \"esm2_t6_8M_UR50D\"\n", + "\n", + "model_name = \"esm3-small-2024-08\"\n", + "epitope_embeddings = pd.read_csv('MixTCRpred/data/epitope_embeddings_'+model_name+'.csv',index_col=0)\n", + "cdr3_embeddings = pd.read_csv('MixTCRpred/data/cdr3_embeddings_'+model_name+'.csv',index_col=0)\n", + "\n", + "############################################# subset of available embeddings ###########################################\n", + "combined_df = combined_df.loc[combined_df[\"epitope\"].isin(epitope_embeddings.index) , :]\n", + "combined_df = combined_df.loc[combined_df[\"cdr3\"].isin(cdr3_embeddings.index) , :]\n", + 
"\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 493 + }, + "id": "UrSmETAb3b3W", + "outputId": "34cd4ed9-558f-4dcd-86e5-6c6a883300ec" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + " TRAV_TCRAV12-1 TRAV_TCRAV17 TRAV_TCRAV19 \\\n", + "cdr3 \n", + "CAGGADRLTFCASSPAGNTLYF 0.0 0.0 0.0 \n", + "CAASGGSNYNVLYFCAWSLWGGPSAETLYF 0.0 0.0 0.0 \n", + "CAASYNYAQGLTFCASRDWGGRQDTQYF 0.0 0.0 0.0 \n", + "CAAQTGNYKYVFCASGDAGTGQDTQYF 0.0 0.0 0.0 \n", + "CAASLTGGYKVVFCAWRTDNQDTQYF 0.0 0.0 0.0 \n", + "... ... ... ... \n", + "CAYRSGEYGNKLVFCASSMAGSSYEQYF 0.0 0.0 0.0 \n", + "CAYRSFNNNDMRFCASRSRGGHSPLHF 0.0 0.0 0.0 \n", + "CATDNDMRFCASSFGPDEQYF 0.0 0.0 0.0 \n", + "CAVLNNARLMFCASSVDRVADTQYF 0.0 0.0 0.0 \n", + "CAMTSFQKLVFCASSLRGEKNNYGYTF 0.0 0.0 0.0 \n", + "\n", + " TRAV_TCRAV21 TRAV_TCRAV23/DV6 TRAV_TCRAV3 \\\n", + "cdr3 \n", + "CAGGADRLTFCASSPAGNTLYF 0.0 0.0 0.0 \n", + "CAASGGSNYNVLYFCAWSLWGGPSAETLYF 0.0 0.0 0.0 \n", + "CAASYNYAQGLTFCASRDWGGRQDTQYF 0.0 0.0 0.0 \n", + "CAAQTGNYKYVFCASGDAGTGQDTQYF 0.0 0.0 0.0 \n", + "CAASLTGGYKVVFCAWRTDNQDTQYF 0.0 0.0 0.0 \n", + "... ... ... ... \n", + "CAYRSGEYGNKLVFCASSMAGSSYEQYF 0.0 0.0 0.0 \n", + "CAYRSFNNNDMRFCASRSRGGHSPLHF 0.0 0.0 0.0 \n", + "CATDNDMRFCASSFGPDEQYF 0.0 0.0 0.0 \n", + "CAVLNNARLMFCASSVDRVADTQYF 0.0 0.0 0.0 \n", + "CAMTSFQKLVFCASSLRGEKNNYGYTF 0.0 0.0 0.0 \n", + "\n", + " TRAV_TCRAV38-1 TRAV_TCRAV38-2/DV8 \\\n", + "cdr3 \n", + "CAGGADRLTFCASSPAGNTLYF 0.0 0.0 \n", + "CAASGGSNYNVLYFCAWSLWGGPSAETLYF 0.0 0.0 \n", + "CAASYNYAQGLTFCASRDWGGRQDTQYF 0.0 0.0 \n", + "CAAQTGNYKYVFCASGDAGTGQDTQYF 0.0 0.0 \n", + "CAASLTGGYKVVFCAWRTDNQDTQYF 0.0 0.0 \n", + "... ... ... 
\n", + "CAYRSGEYGNKLVFCASSMAGSSYEQYF 0.0 0.0 \n", + "CAYRSFNNNDMRFCASRSRGGHSPLHF 0.0 0.0 \n", + "CATDNDMRFCASSFGPDEQYF 0.0 0.0 \n", + "CAVLNNARLMFCASSVDRVADTQYF 0.0 0.0 \n", + "CAMTSFQKLVFCASSLRGEKNNYGYTF 0.0 0.0 \n", + "\n", + " TRAV_TCRAV41 TRAV_TRAV-2 ... TRBJ_TRBJ2-5 \\\n", + "cdr3 ... \n", + "CAGGADRLTFCASSPAGNTLYF 0.0 0.0 ... 0.0 \n", + "CAASGGSNYNVLYFCAWSLWGGPSAETLYF 0.0 0.0 ... 0.0 \n", + "CAASYNYAQGLTFCASRDWGGRQDTQYF 0.0 0.0 ... 1.0 \n", + "CAAQTGNYKYVFCASGDAGTGQDTQYF 0.0 0.0 ... 1.0 \n", + "CAASLTGGYKVVFCAWRTDNQDTQYF 0.0 0.0 ... 1.0 \n", + "... ... ... ... ... \n", + "CAYRSGEYGNKLVFCASSMAGSSYEQYF 0.0 0.0 ... 0.0 \n", + "CAYRSFNNNDMRFCASRSRGGHSPLHF 0.0 0.0 ... 0.0 \n", + "CATDNDMRFCASSFGPDEQYF 0.0 0.0 ... 0.0 \n", + "CAVLNNARLMFCASSVDRVADTQYF 0.0 0.0 ... 0.0 \n", + "CAMTSFQKLVFCASSLRGEKNNYGYTF 0.0 0.0 ... 0.0 \n", + "\n", + " TRBJ_TRBJ2-6 TRBJ_TRBJ2-7 TRBJ_TRBJ2-7 \\\n", + "cdr3 \n", + "CAGGADRLTFCASSPAGNTLYF 0.0 0.0 0.0 \n", + "CAASGGSNYNVLYFCAWSLWGGPSAETLYF 0.0 0.0 0.0 \n", + "CAASYNYAQGLTFCASRDWGGRQDTQYF 0.0 0.0 0.0 \n", + "CAAQTGNYKYVFCASGDAGTGQDTQYF 0.0 0.0 0.0 \n", + "CAASLTGGYKVVFCAWRTDNQDTQYF 0.0 0.0 0.0 \n", + "... ... ... ... \n", + "CAYRSGEYGNKLVFCASSMAGSSYEQYF 0.0 0.0 0.0 \n", + "CAYRSFNNNDMRFCASRSRGGHSPLHF 0.0 0.0 0.0 \n", + "CATDNDMRFCASSFGPDEQYF 0.0 0.0 0.0 \n", + "CAVLNNARLMFCASSVDRVADTQYF 0.0 0.0 0.0 \n", + "CAMTSFQKLVFCASSLRGEKNNYGYTF 0.0 0.0 0.0 \n", + "\n", + " TRBJ_TRBJ20-1 TRBJ_TRBJ24-1 \\\n", + "cdr3 \n", + "CAGGADRLTFCASSPAGNTLYF 0.0 0.0 \n", + "CAASGGSNYNVLYFCAWSLWGGPSAETLYF 0.0 0.0 \n", + "CAASYNYAQGLTFCASRDWGGRQDTQYF 0.0 0.0 \n", + "CAAQTGNYKYVFCASGDAGTGQDTQYF 0.0 0.0 \n", + "CAASLTGGYKVVFCAWRTDNQDTQYF 0.0 0.0 \n", + "... ... ... 
\n", + "CAYRSGEYGNKLVFCASSMAGSSYEQYF 0.0 0.0 \n", + "CAYRSFNNNDMRFCASRSRGGHSPLHF 0.0 0.0 \n", + "CATDNDMRFCASSFGPDEQYF 0.0 0.0 \n", + "CAVLNNARLMFCASSVDRVADTQYF 0.0 0.0 \n", + "CAMTSFQKLVFCASSLRGEKNNYGYTF 0.0 0.0 \n", + "\n", + " TRBJ_TRBJ38-2/DV8 TRBJ_TRBJ5-1 TRBJ_TRBJ5-6 \\\n", + "cdr3 \n", + "CAGGADRLTFCASSPAGNTLYF 0.0 0.0 0.0 \n", + "CAASGGSNYNVLYFCAWSLWGGPSAETLYF 0.0 0.0 0.0 \n", + "CAASYNYAQGLTFCASRDWGGRQDTQYF 0.0 0.0 0.0 \n", + "CAAQTGNYKYVFCASGDAGTGQDTQYF 0.0 0.0 0.0 \n", + "CAASLTGGYKVVFCAWRTDNQDTQYF 0.0 0.0 0.0 \n", + "... ... ... ... \n", + "CAYRSGEYGNKLVFCASSMAGSSYEQYF 0.0 0.0 0.0 \n", + "CAYRSFNNNDMRFCASRSRGGHSPLHF 0.0 0.0 0.0 \n", + "CATDNDMRFCASSFGPDEQYF 0.0 0.0 0.0 \n", + "CAVLNNARLMFCASSVDRVADTQYF 0.0 0.0 0.0 \n", + "CAMTSFQKLVFCASSLRGEKNNYGYTF 0.0 0.0 0.0 \n", + "\n", + " TRBJ_nan \n", + "cdr3 \n", + "CAGGADRLTFCASSPAGNTLYF 0.0 \n", + "CAASGGSNYNVLYFCAWSLWGGPSAETLYF 0.0 \n", + "CAASYNYAQGLTFCASRDWGGRQDTQYF 0.0 \n", + "CAAQTGNYKYVFCASGDAGTGQDTQYF 0.0 \n", + "CAASLTGGYKVVFCAWRTDNQDTQYF 0.0 \n", + "... ... \n", + "CAYRSGEYGNKLVFCASSMAGSSYEQYF 1.0 \n", + "CAYRSFNNNDMRFCASRSRGGHSPLHF 1.0 \n", + "CATDNDMRFCASSFGPDEQYF 1.0 \n", + "CAVLNNARLMFCASSVDRVADTQYF 1.0 \n", + "CAMTSFQKLVFCASSLRGEKNNYGYTF 1.0 \n", + "\n", + "[14106 rows x 508 columns]" + ], + "text/html": [ + "\n", + "
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
TRAV_TCRAV12-1TRAV_TCRAV17TRAV_TCRAV19TRAV_TCRAV21TRAV_TCRAV23/DV6TRAV_TCRAV3TRAV_TCRAV38-1TRAV_TCRAV38-2/DV8TRAV_TCRAV41TRAV_TRAV-2...TRBJ_TRBJ2-5TRBJ_TRBJ2-6TRBJ_TRBJ2-7TRBJ_TRBJ2-7TRBJ_TRBJ20-1TRBJ_TRBJ24-1TRBJ_TRBJ38-2/DV8TRBJ_TRBJ5-1TRBJ_TRBJ5-6TRBJ_nan
cdr3
CAGGADRLTFCASSPAGNTLYF0.00.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.00.0
CAASGGSNYNVLYFCAWSLWGGPSAETLYF0.00.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.00.0
CAASYNYAQGLTFCASRDWGGRQDTQYF0.00.00.00.00.00.00.00.00.00.0...1.00.00.00.00.00.00.00.00.00.0
CAAQTGNYKYVFCASGDAGTGQDTQYF0.00.00.00.00.00.00.00.00.00.0...1.00.00.00.00.00.00.00.00.00.0
CAASLTGGYKVVFCAWRTDNQDTQYF0.00.00.00.00.00.00.00.00.00.0...1.00.00.00.00.00.00.00.00.00.0
..................................................................
CAYRSGEYGNKLVFCASSMAGSSYEQYF0.00.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.01.0
CAYRSFNNNDMRFCASRSRGGHSPLHF0.00.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.01.0
CATDNDMRFCASSFGPDEQYF0.00.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.01.0
CAVLNNARLMFCASSVDRVADTQYF0.00.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.01.0
CAMTSFQKLVFCASSLRGEKNNYGYTF0.00.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.01.0
\n", + "

14106 rows × 508 columns

\n", + "
\n", + "
\n", + "\n", + "
\n", + " \n", + "\n", + " \n", + "\n", + " \n", + "
\n", + "\n", + "\n", + "
\n", + " \n", + "\n", + "\n", + "\n", + " \n", + "
\n", + "\n", + "
\n", + " \n", + " \n", + " \n", + "
\n", + "\n", + "
\n", + "
\n" + ], + "application/vnd.google.colaboratory.intrinsic+json": { + "type": "dataframe", + "variable_name": "df_encoded_TCR" + } + }, + "metadata": {}, + "execution_count": 5 + } + ], + "source": [ + "##################### encode additional information for the TCRs #####################\n", + "from sklearn.preprocessing import OneHotEncoder\n", + "\n", + "# One hot encoding of categorical variables\n", + "columns_to_encode = ['TRAV','TRAJ','TRBV','TRBJ']\n", + "df = combined_df.loc[:,columns_to_encode]\n", + "\n", + "one_hot_encoder = OneHotEncoder(drop='first', sparse_output=False)\n", + "# Fit and transform the data\n", + "one_hot_encoded = one_hot_encoder.fit_transform(df)\n", + "feature_names = one_hot_encoder.get_feature_names_out(df.columns)\n", + "df_encoded = pd.DataFrame(one_hot_encoded, columns=feature_names)\n", + "\n", + "df_encoded.index = combined_df[\"cdr3\"]\n", + "df_encoded_TCR = df_encoded[~df_encoded.index.duplicated(keep=\"first\")]\n", + "df_encoded_TCR" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "id": "vUdehGAmrGra" + }, + "outputs": [], + "source": [ + "###################################### enter parameters ######################################\n", + "# Setting_new_epitope_new_TCR_LOOCV\n", + "\n", + "setting = \"Setting_new_epitope_new_TCR_FINAL_MODEL\"\n", + "\n", + "# \"ESM3 + VJ genes\" \"all features\" \"ESMonly\" \"withoutESM\"\n", + "features_name = \"ESM3 + VJ genes\"\n", + "\n", + "# MHCI MHCII all\n", + "MHC_class = \"all\"\n", + "\n", + "species = \"all\" # HomoSapiens all\n", + "tcr_features = \"cdr3\"\n", + "repetition = 5\n", + "algorithm = \"sklearn_logit\"\n", + "result_folder = \"MixTCRpred/output/\"+setting+\"/\"\n", + "nfolds = 5 # here it is only for Gridsearch\n", + "n_jobs = -1\n", + "\n", + "os.chdir(path)\n", + "os.makedirs(result_folder,exist_ok=True)\n", + "os.chdir(result_folder)\n", + "\n", + "if species != \"all\":\n", + " combined_df = 
combined_df.loc[combined_df[\"species\"]==species,:]\n", + "\n", + "if MHC_class != \"all\":\n", + " combined_df = combined_df.loc[combined_df[\"MHC_class\"]==MHC_class, : ]\n", + "\n", + "\n", + "############################################ percentile background peptides ############################################\n", + "p_test_mean_background = pd.read_csv(\"selected_background_peptides_embeddings_\"+features_name+\"_\"+model_name+'_predicted_prob.csv', index_col=0)\n" + ] + }, + { + "cell_type": "code", + "source": [ + "combined_df.MHC_class.value_counts()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 198 + }, + "id": "C9NE89S-N92J", + "outputId": "e4ba3e64-5038-4e1d-b523-197f2ed57361" + }, + "execution_count": 7, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "MHC_class\n", + "MHCI 13248\n", + "MHCII 4428\n", + "Name: count, dtype: int64" + ], + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
count
MHC_class
MHCI13248
MHCII4428
\n", + "

" + ] + }, + "metadata": {}, + "execution_count": 7 + } + ] + }, + { + "cell_type": "code", + "source": [ + "# load model\n", + "import pickle\n", + "pickle_off = open(\"model_\"+features_name+\".pickle\",\"rb\")\n", + "model = pickle.load(pickle_off)\n" + ], + "metadata": { + "id": "JSI2yTCzryxf" + }, + "execution_count": 8, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "####################################### INPUT: list of epitope sequences #######################################\n", + "\n", + "epitopes = [\"DIYKGMGPLLATVFKSV\",\"GMGPLLATVFKSV\"]\n", + "\n", + "MHC = \"H2-IAb\" # for now, these values do not matter since the model does not incorporate them !\n", + "MHC_class=\"MHCII\" # for now, these values do not matter since the model does not incorporate them !\n", + "species = \"MusMusculus\"# for now, these values do not matter since the model does not incorporate them !\n", + "\n", + "tcr = np.unique(combined_df[tcr_features])\n", + "selected_tcrs = random.sample(list(tcr), min(10000, len(tcr)))\n", + "\n", + "mat_test_tab = pd.DataFrame()\n", + "for epitope in epitopes:\n", + " print(epitope)\n", + " df = pd.DataFrame( {tcr_features : selected_tcrs , \"epitope\": epitope} )\n", + " mat_test_tab = pd.concat( [ mat_test_tab , df ] ).reset_index(drop=True)\n", + "mat_test_tab" + ], + "metadata": { + "id": "MGST9ARllAFr", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 480 + }, + "outputId": "56f8624f-92c0-4a70-ba24-9420ea3317fd" + }, + "execution_count": 10, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "DIYKGMGPLLATVFKSV\n", + "GMGPLLATVFKSV\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + " cdr3 epitope\n", + "0 CAMRNNVGDNSKLIWCASGDAGWSNQDTQYF DIYKGMGPLLATVFKSV\n", + "1 CAWRGGGGADGLTFCASSWDPTYNEQFF DIYKGMGPLLATVFKSV\n", + "2 CILREGFGNVLHCCASSMRSGSEQFF DIYKGMGPLLATVFKSV\n", + "3 CAASIGNNRIFFCAWSLQEDTQYF DIYKGMGPLLATVFKSV\n", + "4 
CAVSSNTGKLIFCASSASRVGEDTQYF DIYKGMGPLLATVFKSV\n", + "... ... ...\n", + "19995 CAVINMGYKLTFCASEDWGGAHAEQFF GMGPLLATVFKSV\n", + "19996 CALSGYTEGADRLTFCASSERNSGNTLYF GMGPLLATVFKSV\n", + "19997 CVLSANNNAGAKLTFCASSDAAREGQNTLYF GMGPLLATVFKSV\n", + "19998 CAVRDQAGTALIFCASSFGPVEQYF GMGPLLATVFKSV\n", + "19999 CALSGFTDKLIFCASSLRDGYTDTQYF GMGPLLATVFKSV\n", + "\n", + "[20000 rows x 2 columns]" + ], + "text/html": [ + "\n", + "
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
cdr3epitope
0CAMRNNVGDNSKLIWCASGDAGWSNQDTQYFDIYKGMGPLLATVFKSV
1CAWRGGGGADGLTFCASSWDPTYNEQFFDIYKGMGPLLATVFKSV
2CILREGFGNVLHCCASSMRSGSEQFFDIYKGMGPLLATVFKSV
3CAASIGNNRIFFCAWSLQEDTQYFDIYKGMGPLLATVFKSV
4CAVSSNTGKLIFCASSASRVGEDTQYFDIYKGMGPLLATVFKSV
.........
19995CAVINMGYKLTFCASEDWGGAHAEQFFGMGPLLATVFKSV
19996CALSGYTEGADRLTFCASSERNSGNTLYFGMGPLLATVFKSV
19997CVLSANNNAGAKLTFCASSDAAREGQNTLYFGMGPLLATVFKSV
19998CAVRDQAGTALIFCASSFGPVEQYFGMGPLLATVFKSV
19999CALSGFTDKLIFCASSLRDGYTDTQYFGMGPLLATVFKSV
\n", + "

20000 rows × 2 columns

\n", + "
\n", + "
\n", + "\n", + "
\n", + " \n", + "\n", + " \n", + "\n", + " \n", + "
\n", + "\n", + "\n", + "
\n", + " \n", + "\n", + "\n", + "\n", + " \n", + "
\n", + "\n", + "
\n", + " \n", + " \n", + " \n", + "
\n", + "\n", + "
\n", + "
\n" + ], + "application/vnd.google.colaboratory.intrinsic+json": { + "type": "dataframe", + "variable_name": "mat_test_tab", + "summary": "{\n \"name\": \"mat_test_tab\",\n \"rows\": 20000,\n \"fields\": [\n {\n \"column\": \"cdr3\",\n \"properties\": {\n \"dtype\": \"string\",\n \"num_unique_values\": 10000,\n \"samples\": [\n \"CAASNNRIFFCAWTGGIDEQYF\",\n \"CVVNDPSGNTPLVFCSARDERALNTGELFF\",\n \"CGNTGNQFYFCASSQGVGYTF\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"epitope\",\n \"properties\": {\n \"dtype\": \"category\",\n \"num_unique_values\": 2,\n \"samples\": [\n \"GMGPLLATVFKSV\",\n \"DIYKGMGPLLATVFKSV\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}" + } + }, + "metadata": {}, + "execution_count": 10 + } + ] + }, + { + "cell_type": "code", + "source": [ + "\n", + "################################# encoding epitope #################################\n", + "\n", + "epitope_embeddings = pd.DataFrame()\n", + "for epitope in epitopes:\n", + " print(epitope)\n", + " results = run_ESM3(epitope,ESM3_model)\n", + "\n", + " embed_total = pd.concat([results[0]], axis=1).transpose()\n", + " sequence_left = [results[1]]\n", + " embed_total.index = sequence_left\n", + " # Convert column names to strings before adding \"ESM3_\"\n", + " embed_total.columns = \"ESM3_\" + embed_total.columns.astype(str)\n", + " epitope_embeddings = pd.concat( [ epitope_embeddings , embed_total.copy() ] )\n", + "epitope_embeddings" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 197 + }, + "id": "SkvF4abSQmhO", + "outputId": "aac0c0f6-bc7b-4c53-9082-6a1b96c17b38" + }, + "execution_count": 11, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "DIYKGMGPLLATVFKSV\n", + "GMGPLLATVFKSV\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + " ESM3_0 ESM3_1 ESM3_2 ESM3_3 ESM3_4 ESM3_5 \\\n", + "DIYKGMGPLLATVFKSV 0.023491 -0.103977 -0.012792 
0.017144 0.038514 -0.093617 \n", + "GMGPLLATVFKSV 0.083255 -0.080921 -0.009170 0.004828 0.053290 -0.087519 \n", + "\n", + " ESM3_6 ESM3_7 ESM3_8 ESM3_9 ... ESM3_1526 \\\n", + "DIYKGMGPLLATVFKSV -0.049821 -0.009365 0.126688 -0.118027 ... 0.127035 \n", + "GMGPLLATVFKSV -0.056259 -0.005152 0.104326 -0.114268 ... 0.150475 \n", + "\n", + " ESM3_1527 ESM3_1528 ESM3_1529 ESM3_1530 ESM3_1531 \\\n", + "DIYKGMGPLLATVFKSV 0.001517 0.115759 0.041184 0.046807 -0.103359 \n", + "GMGPLLATVFKSV 0.080277 0.104821 -0.016300 0.032343 -0.084724 \n", + "\n", + " ESM3_1532 ESM3_1533 ESM3_1534 ESM3_1535 \n", + "DIYKGMGPLLATVFKSV 0.081879 0.041203 -0.049147 -0.609815 \n", + "GMGPLLATVFKSV 0.076475 0.116125 -0.067687 -0.679018 \n", + "\n", + "[2 rows x 1536 columns]" + ], + "text/html": [ + "\n", + "
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
ESM3_0ESM3_1ESM3_2ESM3_3ESM3_4ESM3_5ESM3_6ESM3_7ESM3_8ESM3_9...ESM3_1526ESM3_1527ESM3_1528ESM3_1529ESM3_1530ESM3_1531ESM3_1532ESM3_1533ESM3_1534ESM3_1535
DIYKGMGPLLATVFKSV0.023491-0.103977-0.0127920.0171440.038514-0.093617-0.049821-0.0093650.126688-0.118027...0.1270350.0015170.1157590.0411840.046807-0.1033590.0818790.041203-0.049147-0.609815
GMGPLLATVFKSV0.083255-0.080921-0.0091700.0048280.053290-0.087519-0.056259-0.0051520.104326-0.114268...0.1504750.0802770.104821-0.0163000.032343-0.0847240.0764750.116125-0.067687-0.679018
\n", + "

2 rows × 1536 columns

\n", + "
\n", + "
\n", + "\n", + "
\n", + " \n", + "\n", + " \n", + "\n", + " \n", + "
\n", + "\n", + "\n", + "
\n", + " \n", + "\n", + "\n", + "\n", + " \n", + "
\n", + "\n", + "
\n", + " \n", + " \n", + " \n", + "
\n", + "\n", + "
\n", + "
\n" + ], + "application/vnd.google.colaboratory.intrinsic+json": { + "type": "dataframe", + "variable_name": "epitope_embeddings" + } + }, + "metadata": {}, + "execution_count": 11 + } + ] + }, + { + "cell_type": "code", + "source": [ + "#################### encode additional information for the epitopes ####################\n", + "from sklearn.preprocessing import OneHotEncoder\n", + "\n", + "# One hot encoding of categorical variables\n", + "columns_to_encode = ['MHC','MHC_class','species']\n", + "df = combined_df.loc[:,columns_to_encode]\n", + "\n", + "one_hot_encoder = OneHotEncoder(drop='first', sparse_output=False)\n", + "# Fit and transform the data\n", + "one_hot_encoded = one_hot_encoder.fit_transform(df)\n", + "feature_names = one_hot_encoder.get_feature_names_out(df.columns)\n", + "df_encoded = pd.DataFrame(one_hot_encoded, columns=feature_names)\n", + "\n", + "df_encoded.index = combined_df[\"epitope\"]\n", + "df_encoded_reference = df_encoded[~df_encoded.index.duplicated(keep=\"first\")]\n", + "\n", + "# Create a new row with all zeros and index name as epitope\n", + "df_encoded_epitope = pd.DataFrame()\n", + "for epitope in epitopes:\n", + " print(epitope)\n", + " new_row = pd.DataFrame(index=[epitope],columns=df_encoded_reference.columns)\n", + " df_encoded_reference_temp = pd.concat([new_row, df_encoded_reference])\n", + "\n", + "\n", + " if \"MHC_\" + MHC not in df_encoded_reference_temp.columns:\n", + " cols_to_zero = [col for col in df_encoded_reference_temp.columns if \"MHC_\" in col]\n", + " df_encoded_reference_temp[cols_to_zero] = 0\n", + " else:\n", + " df_encoded_reference_temp.loc[epitope,[\"MHC_\"+MHC]] = 1\n", + "\n", + " if \"MHC_class_\" + MHC_class not in df_encoded_reference_temp.columns:\n", + " cols_to_zero = [col for col in df_encoded_reference_temp.columns if \"MHC_class_\" in col]\n", + " df_encoded_reference_temp[cols_to_zero] = 0\n", + " else:\n", + " df_encoded_reference_temp.loc[epitope,[\"MHC_class_\"+MHC_class]] = 1\n", + 
"\n", + " if \"species_\" + species not in df_encoded_reference_temp.columns:\n", + " cols_to_zero = [col for col in df_encoded_reference_temp.columns if \"species_\" in col]\n", + " df_encoded_reference_temp[cols_to_zero] = 0\n", + " else:\n", + " df_encoded_reference_temp.loc[epitope,[\"species_\"+species]] = 1\n", + "\n", + "\n", + " df_encoded_reference_temp = df_encoded_reference_temp.fillna(0)\n", + " df_encoded_epitope = pd.concat( [ df_encoded_epitope , df_encoded_reference_temp.iloc[[0]] ] )\n", + "\n", + "df_encoded_epitope\n", + "\n", + "" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 304 + }, + "id": "89a8vzOOQ1f8", + "outputId": "5989455f-d55b-47b1-f438-7601f56bdd7f" + }, + "execution_count": 12, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "DIYKGMGPLLATVFKSV\n", + "GMGPLLATVFKSV\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + ":22: FutureWarning: The behavior of DataFrame concatenation with empty or all-NA entries is deprecated. In a future version, this will no longer exclude empty or all-NA columns when determining the result dtypes. To retain the old behavior, exclude the relevant entries before the concat operation.\n", + " df_encoded_reference_temp = pd.concat([new_row, df_encoded_reference])\n", + ":22: FutureWarning: The behavior of DataFrame concatenation with empty or all-NA entries is deprecated. In a future version, this will no longer exclude empty or all-NA columns when determining the result dtypes. 
To retain the old behavior, exclude the relevant entries before the concat operation.\n", + " df_encoded_reference_temp = pd.concat([new_row, df_encoded_reference])\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + " MHC_H2-Db MHC_H2-IAb MHC_H2-IEk MHC_H2-Kb MHC_H2-Kd \\\n", + "DIYKGMGPLLATVFKSV 0.0 1.0 0.0 0.0 0.0 \n", + "GMGPLLATVFKSV 0.0 1.0 0.0 0.0 0.0 \n", + "\n", + " MHC_H2-Ld MHC_HLA-A*02:01 MHC_HLA-A*08:01 \\\n", + "DIYKGMGPLLATVFKSV 0.0 0.0 0.0 \n", + "GMGPLLATVFKSV 0.0 0.0 0.0 \n", + "\n", + " MHC_HLA-A*11:01 MHC_HLA-A*24:02 ... \\\n", + "DIYKGMGPLLATVFKSV 0.0 0.0 ... \n", + "GMGPLLATVFKSV 0.0 0.0 ... \n", + "\n", + " MHC_HLA-DQA1:02/DQB1*06:02 MHC_HLA-DRA:01 \\\n", + "DIYKGMGPLLATVFKSV 0.0 0.0 \n", + "GMGPLLATVFKSV 0.0 0.0 \n", + "\n", + " MHC_HLA-DRA:01/DRB1:01 MHC_HLA-DRB1*04:01 \\\n", + "DIYKGMGPLLATVFKSV 0.0 0.0 \n", + "GMGPLLATVFKSV 0.0 0.0 \n", + "\n", + " MHC_HLA-DRB1*04:05 MHC_HLA-DRB1*07:01 MHC_HLA-DRB1*11:01 \\\n", + "DIYKGMGPLLATVFKSV 0.0 0.0 0.0 \n", + "GMGPLLATVFKSV 0.0 0.0 0.0 \n", + "\n", + " MHC_HLA-DRB1:01 MHC_class_MHCII species_MusMusculus \n", + "DIYKGMGPLLATVFKSV 0.0 1.0 1.0 \n", + "GMGPLLATVFKSV 0.0 1.0 1.0 \n", + "\n", + "[2 rows x 37 columns]" + ], + "text/html": [ + "\n", + "
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
MHC_H2-DbMHC_H2-IAbMHC_H2-IEkMHC_H2-KbMHC_H2-KdMHC_H2-LdMHC_HLA-A*02:01MHC_HLA-A*08:01MHC_HLA-A*11:01MHC_HLA-A*24:02...MHC_HLA-DQA1:02/DQB1*06:02MHC_HLA-DRA:01MHC_HLA-DRA:01/DRB1:01MHC_HLA-DRB1*04:01MHC_HLA-DRB1*04:05MHC_HLA-DRB1*07:01MHC_HLA-DRB1*11:01MHC_HLA-DRB1:01MHC_class_MHCIIspecies_MusMusculus
DIYKGMGPLLATVFKSV0.01.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.01.01.0
GMGPLLATVFKSV0.01.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.01.01.0
\n", + "

2 rows × 37 columns

\n", + "
\n", + "
\n", + "\n", + "
\n", + " \n", + "\n", + " \n", + "\n", + " \n", + "
\n", + "\n", + "\n", + "
\n", + " \n", + "\n", + "\n", + "\n", + " \n", + "
\n", + "\n", + "
\n", + " \n", + " \n", + " \n", + "
\n", + "\n", + "
\n", + "
############################################### run prediction on epitope ###############################################

# run_prediction (defined earlier in the notebook) builds one feature row per
# row of mat_test_tab from the epitope and CDR3 embeddings; the next step
# relies on its result carrying "epitope" and "predicted_prob" columns.
p_test = run_prediction(mat_test_tab, epitope_embeddings, cdr3_embeddings)

############################################# average predicted probability #############################################

# Keep only the epitope label and its predicted probability, then collapse the
# per-row predictions into a single mean probability per epitope.
p_test = p_test[["epitope", "predicted_prob"]]
p_test_mean = p_test.groupby("epitope").mean()
p_test_mean
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
predicted_prob
epitope
DIYKGMGPLLATVFKSV0.880342
GMGPLLATVFKSV0.604100
\n", + "
\n", + "
\n", + "\n", + "
\n", + " \n", + "\n", + " \n", + "\n", + " \n", + "
\n", + "\n", + "\n", + "
\n", + " \n", + "\n", + "\n", + "\n", + " \n", + "
\n", + "\n", + "
\n", + " \n", + " \n", + " \n", + "
\n", + "\n", + "
\n", + "
################################################ compute Percentile Rank ################################################
from scipy import stats

# Background distribution of mean predicted probabilities (computed in an
# earlier cell). Looked up once because it is the same for every epitope.
background_probs = p_test_mean_background["predicted_prob"]

# Collect one (epitope, rank) record per epitope and build the table once at
# the end. Concatenating a one-row frame per iteration gave every row the
# index label 0 and left both columns with dtype `object`.
records = []
for epitope, mean_predicted_prob in p_test_mean["predicted_prob"].items():

    # percentileofscore gives the percentage of background scores below this
    # epitope's score (ties averaged, scipy's default kind="rank"). Subtracting
    # from 100 flips it so that a high predicted probability yields a low rank.
    percentile_rank = round(100 - stats.percentileofscore(background_probs, mean_predicted_prob), 2)
    records.append((epitope, percentile_rank))

    # Print the percentile rank
    print(f"Percentile Rank: {percentile_rank}")

# Percentile Rank from 0 to 100. The closer to 0, the stronger the predicted TCR recognition.
# Passing `columns=` to the constructor also keeps this working when there are
# no epitopes (assigning two column names to an empty frame raises ValueError).
percentile_rank_all = pd.DataFrame(records, columns=["epitope", "Percentile Rank"])
percentile_rank_all
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
epitopePercentile Rank
0DIYKGMGPLLATVFKSV1.25
0GMGPLLATVFKSV16.15
\n", + "
\n", + "
\n", + "\n", + "
\n", + " \n", + "\n", + " \n", + "\n", + " \n", + "
\n", + "\n", + "\n", + "
\n", + " \n", + "\n", + "\n", + "\n", + " \n", + "
\n", + "\n", + "
\n", + " \n", + " \n", + " \n", + "
\n", + "\n", + "
\n", + "
\n" + ], + "application/vnd.google.colaboratory.intrinsic+json": { + "type": "dataframe", + "variable_name": "percentile_rank_all", + "summary": "{\n \"name\": \"percentile_rank_all\",\n \"rows\": 2,\n \"fields\": [\n {\n \"column\": \"epitope\",\n \"properties\": {\n \"dtype\": \"string\",\n \"num_unique_values\": 2,\n \"samples\": [\n \"GMGPLLATVFKSV\",\n \"DIYKGMGPLLATVFKSV\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Percentile Rank\",\n \"properties\": {\n \"dtype\": \"date\",\n \"min\": 1.25,\n \"max\": 16.15,\n \"num_unique_values\": 2,\n \"samples\": [\n 16.15,\n 1.25\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}" + } + }, + "metadata": {}, + "execution_count": 22 + } + ] + } + ], + "metadata": { + "colab": { + "provenance": [], + "machine_shape": "hm" + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file