From cd997185ba959a44746a79068abea576efdafeb1 Mon Sep 17 00:00:00 2001
From: Mi Yang <11358894+Sophon-0@users.noreply.github.com>
Date: Thu, 17 Oct 2024 13:53:20 -0400
Subject: [PATCH] prediction of TCR recognition in percentile rank for a given
list of epitopes
---
...diction_NewEpitope_NewTCR_PREDICTION.ipynb | 3164 +++++++++++++++++
1 file changed, 3164 insertions(+)
create mode 100644 notebooks/TCR_epitope_prediction_NewEpitope_NewTCR_PREDICTION.ipynb
diff --git a/notebooks/TCR_epitope_prediction_NewEpitope_NewTCR_PREDICTION.ipynb b/notebooks/TCR_epitope_prediction_NewEpitope_NewTCR_PREDICTION.ipynb
new file mode 100644
index 0000000..a72a32c
--- /dev/null
+++ b/notebooks/TCR_epitope_prediction_NewEpitope_NewTCR_PREDICTION.ipynb
@@ -0,0 +1,3164 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "id": "-NSHxGYaV7vc"
+ },
+ "outputs": [],
+ "source": [
+ "%%capture\n",
+ "!pip install esm\n",
+ "\n",
+ "import os\n",
+ "from google.colab import drive\n",
+ "drive.mount('/content/drive')\n",
+ "path = \"/content/drive/MyDrive/COLAB/TCR_projects\"\n",
+ "os.chdir(path)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from esm.sdk import client\n",
+ "from getpass import getpass\n",
+ "\n",
+ "model_name = \"esm3-small-2024-08\"\n",
+ "\n",
+ "# Evolutionary Scale API key Secret: 6LIhzrmbfzsd1rEFMmnZkc\n",
+ "token = getpass(\"Token from Forge console: \")\n",
+ "ESM3_model = client(\n",
+ " model=model_name, # https://forge.evolutionaryscale.ai/console\n",
+ " url=\"https://forge.evolutionaryscale.ai\",\n",
+ " token=token,\n",
+ ")\n"
+ ],
+ "metadata": {
+ "id": "s3ckojV6roA2",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "7f568e22-9db0-494f-e7a8-b79ccbc7793e"
+ },
+ "execution_count": 2,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "/usr/local/lib/python3.10/dist-packages/google/cloud/storage/transfer_manager.py:30: UserWarning: The module `transfer_manager` is a preview feature. Functionality and API may change. This warning will be removed in a future release.\n",
+ " warnings.warn(\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Token from Forge console: ··········\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {
+ "id": "be5dH_JwZLlu"
+ },
+ "outputs": [],
+ "source": [
+ "import random\n",
+ "import pandas as pd\n",
+ "import numpy as np\n",
+ "\n",
+ "from sklearn.model_selection import GridSearchCV\n",
+ "from sklearn.linear_model import LogisticRegression\n",
+ "from sklearn.ensemble import RandomForestClassifier\n",
+ "from sklearn.neural_network import MLPClassifier\n",
+ "from sklearn.preprocessing import StandardScaler\n",
+ "\n",
+ "def preprocess_features(feat, res, train_indices, test_indices):\n",
+ " x_train = feat.iloc[train_indices, :]\n",
+ " y_train = res[train_indices]\n",
+ " x_test = feat.iloc[test_indices, :]\n",
+ " y_test = res[test_indices]\n",
+ " # scale the data\n",
+ " scaler = StandardScaler().fit(x_train)\n",
+ " x_train = pd.DataFrame(scaler.transform(x_train), index=x_train.index, columns=x_train.columns)\n",
+ " x_test = pd.DataFrame(scaler.transform(x_test), index=x_test.index, columns=x_test.columns)\n",
+ " return x_train, y_train, x_test, y_test\n",
+ "\n",
+ "\n",
+ "def run_prediction(mat_test_tab,epitope_embeddings,tcr_embeddings):\n",
+ "\n",
+ " def get_embeddings(row):\n",
+ " epitope = epitope_embeddings.loc[row['epitope']].values\n",
+ " tcr = tcr_embeddings.loc[row[tcr_features]].values\n",
+ " return np.concatenate((epitope, tcr))\n",
+ "\n",
+ " ################# test set features\n",
+ " features_test = mat_test_tab.apply(get_embeddings, axis=1)\n",
+ " features_test = pd.DataFrame(features_test.tolist(), index=features_test.index)\n",
+ " features_test.index = mat_test_tab[\"epitope\"] + \"_\" + mat_test_tab[tcr_features]\n",
+ " features_test.columns = epitope_embeddings.columns.tolist() + tcr_embeddings.columns.tolist()\n",
+ " ## add other information\n",
+ " df_encoded_TCR_subset = df_encoded_TCR.loc[mat_test_tab[tcr_features], : ]\n",
+ " df_encoded_epitope_subset = df_encoded_epitope.loc[mat_test_tab[\"epitope\"], : ]\n",
+ "\n",
+ " ## combine\n",
+ " # \"ESM3 + VJ genes\" \"all features\" \"ESMonly\" \"withoutESM\"\n",
+ " if features_name == \"ESM3 + VJ genes\":\n",
+ " features_test_all = pd.concat([features_test.reset_index(drop=True), df_encoded_TCR_subset.reset_index(drop=True)], axis=1)\n",
+ " if features_name == \"all features\":\n",
+ " features_test_all = pd.concat([features_test.reset_index(drop=True), df_encoded_TCR_subset.reset_index(drop=True), df_encoded_epitope_subset.reset_index(drop=True)], axis=1)\n",
+ " if features_name == \"ESMonly\":\n",
+ " features_test_all = features_test\n",
+ " if features_name == \"withoutESM\":\n",
+ " features_test_all = pd.concat([df_encoded_TCR_subset.reset_index(drop=True), df_encoded_epitope_subset.reset_index(drop=True)], axis=1)\n",
+ "\n",
+ " features_test_all.index = features_test.index\n",
+ "\n",
+ " ############################################ run ML ############################################\n",
+ " X_test = features_test_all\n",
+ " X_test.columns = X_test.columns.astype(str)\n",
+ "\n",
+ " x_test_proba = model.predict_proba(X_test)[:, 1]\n",
+ "\n",
+ " p_test = pd.DataFrame(\n",
+ " { 'split': \"test\",\n",
+ " 'epitope': mat_test_tab.epitope,\n",
+ " 'sample': X_test.index,\n",
+ " 'predicted_prob': x_test_proba\n",
+ " }\n",
+ " )\n",
+ "\n",
+ " return p_test\n",
+ "\n",
+ "\n",
+ "\n",
+ "def run_ESM3(seq,model):\n",
+ " try:\n",
+ " from esm.models.esm3 import ESM3\n",
+ " from esm.sdk.api import ESMProtein, SamplingConfig\n",
+ " from esm.utils.constants.models import ESM3_OPEN_SMALL\n",
+ "\n",
+ " # Create an ESMProtein object\n",
+ " protein = ESMProtein(sequence=seq)\n",
+ "\n",
+ " # Encode the protein\n",
+ " protein_tensor = model.encode(protein)\n",
+ "\n",
+ " # Get the embeddings\n",
+ " output = model.forward_and_sample(\n",
+ " protein_tensor,\n",
+ " SamplingConfig(return_per_residue_embeddings=True)\n",
+ " )\n",
+ "\n",
+ " # aggregate the per residue embedding\n",
+ " df = pd.DataFrame(output.per_residue_embedding)\n",
+ " column_means = df.mean(axis=0)\n",
+ " return column_means.transpose(), seq\n",
+ "\n",
+ " except Exception as e:\n",
+ " print(f\"Error processing sequence: {seq}\") # Print the problematic sequence\n",
+ " print(e) # Print the exception details\n",
+ " return None # Or handle the error differently\n",
+ "\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {
+ "id": "eB-K1R38U_Y3"
+ },
+ "outputs": [],
+ "source": [
+ "######################################## TCR-Epitope Binding Affinity Prediction Task #################################\n",
+ "os.chdir(path)\n",
+ "combined_df = pd.read_csv(\"MixTCRpred/full_training_set_146pmhc.csv\")\n",
+ "\n",
+ "# combine cdr3\n",
+ "combined_df[\"cdr3\"] = combined_df[\"cdr3_TRA\"] + combined_df[\"cdr3_TRB\"]\n",
+ "combined_df[\"value\"] = 1\n",
+ "combined_df.index = combined_df[\"epitope\"] + \"_\" + combined_df[\"cdr3\"]\n",
+ "combined_df\n",
+ "\n",
+ "##################################################### choose ESM model #################################################\n",
+ "\n",
+ "# \"esm3-small-2024-08\" \"esm2_t6_8M_UR50D\"\n",
+ "\n",
+ "model_name = \"esm3-small-2024-08\"\n",
+ "epitope_embeddings = pd.read_csv('MixTCRpred/data/epitope_embeddings_'+model_name+'.csv',index_col=0)\n",
+ "cdr3_embeddings = pd.read_csv('MixTCRpred/data/cdr3_embeddings_'+model_name+'.csv',index_col=0)\n",
+ "\n",
+ "############################################# subset of available embeddings ###########################################\n",
+ "combined_df = combined_df.loc[combined_df[\"epitope\"].isin(epitope_embeddings.index) , :]\n",
+ "combined_df = combined_df.loc[combined_df[\"cdr3\"].isin(cdr3_embeddings.index) , :]\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 493
+ },
+ "id": "UrSmETAb3b3W",
+ "outputId": "34cd4ed9-558f-4dcd-86e5-6c6a883300ec"
+ },
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ " TRAV_TCRAV12-1 TRAV_TCRAV17 TRAV_TCRAV19 \\\n",
+ "cdr3 \n",
+ "CAGGADRLTFCASSPAGNTLYF 0.0 0.0 0.0 \n",
+ "CAASGGSNYNVLYFCAWSLWGGPSAETLYF 0.0 0.0 0.0 \n",
+ "CAASYNYAQGLTFCASRDWGGRQDTQYF 0.0 0.0 0.0 \n",
+ "CAAQTGNYKYVFCASGDAGTGQDTQYF 0.0 0.0 0.0 \n",
+ "CAASLTGGYKVVFCAWRTDNQDTQYF 0.0 0.0 0.0 \n",
+ "... ... ... ... \n",
+ "CAYRSGEYGNKLVFCASSMAGSSYEQYF 0.0 0.0 0.0 \n",
+ "CAYRSFNNNDMRFCASRSRGGHSPLHF 0.0 0.0 0.0 \n",
+ "CATDNDMRFCASSFGPDEQYF 0.0 0.0 0.0 \n",
+ "CAVLNNARLMFCASSVDRVADTQYF 0.0 0.0 0.0 \n",
+ "CAMTSFQKLVFCASSLRGEKNNYGYTF 0.0 0.0 0.0 \n",
+ "\n",
+ " TRAV_TCRAV21 TRAV_TCRAV23/DV6 TRAV_TCRAV3 \\\n",
+ "cdr3 \n",
+ "CAGGADRLTFCASSPAGNTLYF 0.0 0.0 0.0 \n",
+ "CAASGGSNYNVLYFCAWSLWGGPSAETLYF 0.0 0.0 0.0 \n",
+ "CAASYNYAQGLTFCASRDWGGRQDTQYF 0.0 0.0 0.0 \n",
+ "CAAQTGNYKYVFCASGDAGTGQDTQYF 0.0 0.0 0.0 \n",
+ "CAASLTGGYKVVFCAWRTDNQDTQYF 0.0 0.0 0.0 \n",
+ "... ... ... ... \n",
+ "CAYRSGEYGNKLVFCASSMAGSSYEQYF 0.0 0.0 0.0 \n",
+ "CAYRSFNNNDMRFCASRSRGGHSPLHF 0.0 0.0 0.0 \n",
+ "CATDNDMRFCASSFGPDEQYF 0.0 0.0 0.0 \n",
+ "CAVLNNARLMFCASSVDRVADTQYF 0.0 0.0 0.0 \n",
+ "CAMTSFQKLVFCASSLRGEKNNYGYTF 0.0 0.0 0.0 \n",
+ "\n",
+ " TRAV_TCRAV38-1 TRAV_TCRAV38-2/DV8 \\\n",
+ "cdr3 \n",
+ "CAGGADRLTFCASSPAGNTLYF 0.0 0.0 \n",
+ "CAASGGSNYNVLYFCAWSLWGGPSAETLYF 0.0 0.0 \n",
+ "CAASYNYAQGLTFCASRDWGGRQDTQYF 0.0 0.0 \n",
+ "CAAQTGNYKYVFCASGDAGTGQDTQYF 0.0 0.0 \n",
+ "CAASLTGGYKVVFCAWRTDNQDTQYF 0.0 0.0 \n",
+ "... ... ... \n",
+ "CAYRSGEYGNKLVFCASSMAGSSYEQYF 0.0 0.0 \n",
+ "CAYRSFNNNDMRFCASRSRGGHSPLHF 0.0 0.0 \n",
+ "CATDNDMRFCASSFGPDEQYF 0.0 0.0 \n",
+ "CAVLNNARLMFCASSVDRVADTQYF 0.0 0.0 \n",
+ "CAMTSFQKLVFCASSLRGEKNNYGYTF 0.0 0.0 \n",
+ "\n",
+ " TRAV_TCRAV41 TRAV_TRAV-2 ... TRBJ_TRBJ2-5 \\\n",
+ "cdr3 ... \n",
+ "CAGGADRLTFCASSPAGNTLYF 0.0 0.0 ... 0.0 \n",
+ "CAASGGSNYNVLYFCAWSLWGGPSAETLYF 0.0 0.0 ... 0.0 \n",
+ "CAASYNYAQGLTFCASRDWGGRQDTQYF 0.0 0.0 ... 1.0 \n",
+ "CAAQTGNYKYVFCASGDAGTGQDTQYF 0.0 0.0 ... 1.0 \n",
+ "CAASLTGGYKVVFCAWRTDNQDTQYF 0.0 0.0 ... 1.0 \n",
+ "... ... ... ... ... \n",
+ "CAYRSGEYGNKLVFCASSMAGSSYEQYF 0.0 0.0 ... 0.0 \n",
+ "CAYRSFNNNDMRFCASRSRGGHSPLHF 0.0 0.0 ... 0.0 \n",
+ "CATDNDMRFCASSFGPDEQYF 0.0 0.0 ... 0.0 \n",
+ "CAVLNNARLMFCASSVDRVADTQYF 0.0 0.0 ... 0.0 \n",
+ "CAMTSFQKLVFCASSLRGEKNNYGYTF 0.0 0.0 ... 0.0 \n",
+ "\n",
+ " TRBJ_TRBJ2-6 TRBJ_TRBJ2-7 TRBJ_TRBJ2-7 \\\n",
+ "cdr3 \n",
+ "CAGGADRLTFCASSPAGNTLYF 0.0 0.0 0.0 \n",
+ "CAASGGSNYNVLYFCAWSLWGGPSAETLYF 0.0 0.0 0.0 \n",
+ "CAASYNYAQGLTFCASRDWGGRQDTQYF 0.0 0.0 0.0 \n",
+ "CAAQTGNYKYVFCASGDAGTGQDTQYF 0.0 0.0 0.0 \n",
+ "CAASLTGGYKVVFCAWRTDNQDTQYF 0.0 0.0 0.0 \n",
+ "... ... ... ... \n",
+ "CAYRSGEYGNKLVFCASSMAGSSYEQYF 0.0 0.0 0.0 \n",
+ "CAYRSFNNNDMRFCASRSRGGHSPLHF 0.0 0.0 0.0 \n",
+ "CATDNDMRFCASSFGPDEQYF 0.0 0.0 0.0 \n",
+ "CAVLNNARLMFCASSVDRVADTQYF 0.0 0.0 0.0 \n",
+ "CAMTSFQKLVFCASSLRGEKNNYGYTF 0.0 0.0 0.0 \n",
+ "\n",
+ " TRBJ_TRBJ20-1 TRBJ_TRBJ24-1 \\\n",
+ "cdr3 \n",
+ "CAGGADRLTFCASSPAGNTLYF 0.0 0.0 \n",
+ "CAASGGSNYNVLYFCAWSLWGGPSAETLYF 0.0 0.0 \n",
+ "CAASYNYAQGLTFCASRDWGGRQDTQYF 0.0 0.0 \n",
+ "CAAQTGNYKYVFCASGDAGTGQDTQYF 0.0 0.0 \n",
+ "CAASLTGGYKVVFCAWRTDNQDTQYF 0.0 0.0 \n",
+ "... ... ... \n",
+ "CAYRSGEYGNKLVFCASSMAGSSYEQYF 0.0 0.0 \n",
+ "CAYRSFNNNDMRFCASRSRGGHSPLHF 0.0 0.0 \n",
+ "CATDNDMRFCASSFGPDEQYF 0.0 0.0 \n",
+ "CAVLNNARLMFCASSVDRVADTQYF 0.0 0.0 \n",
+ "CAMTSFQKLVFCASSLRGEKNNYGYTF 0.0 0.0 \n",
+ "\n",
+ " TRBJ_TRBJ38-2/DV8 TRBJ_TRBJ5-1 TRBJ_TRBJ5-6 \\\n",
+ "cdr3 \n",
+ "CAGGADRLTFCASSPAGNTLYF 0.0 0.0 0.0 \n",
+ "CAASGGSNYNVLYFCAWSLWGGPSAETLYF 0.0 0.0 0.0 \n",
+ "CAASYNYAQGLTFCASRDWGGRQDTQYF 0.0 0.0 0.0 \n",
+ "CAAQTGNYKYVFCASGDAGTGQDTQYF 0.0 0.0 0.0 \n",
+ "CAASLTGGYKVVFCAWRTDNQDTQYF 0.0 0.0 0.0 \n",
+ "... ... ... ... \n",
+ "CAYRSGEYGNKLVFCASSMAGSSYEQYF 0.0 0.0 0.0 \n",
+ "CAYRSFNNNDMRFCASRSRGGHSPLHF 0.0 0.0 0.0 \n",
+ "CATDNDMRFCASSFGPDEQYF 0.0 0.0 0.0 \n",
+ "CAVLNNARLMFCASSVDRVADTQYF 0.0 0.0 0.0 \n",
+ "CAMTSFQKLVFCASSLRGEKNNYGYTF 0.0 0.0 0.0 \n",
+ "\n",
+ " TRBJ_nan \n",
+ "cdr3 \n",
+ "CAGGADRLTFCASSPAGNTLYF 0.0 \n",
+ "CAASGGSNYNVLYFCAWSLWGGPSAETLYF 0.0 \n",
+ "CAASYNYAQGLTFCASRDWGGRQDTQYF 0.0 \n",
+ "CAAQTGNYKYVFCASGDAGTGQDTQYF 0.0 \n",
+ "CAASLTGGYKVVFCAWRTDNQDTQYF 0.0 \n",
+ "... ... \n",
+ "CAYRSGEYGNKLVFCASSMAGSSYEQYF 1.0 \n",
+ "CAYRSFNNNDMRFCASRSRGGHSPLHF 1.0 \n",
+ "CATDNDMRFCASSFGPDEQYF 1.0 \n",
+ "CAVLNNARLMFCASSVDRVADTQYF 1.0 \n",
+ "CAMTSFQKLVFCASSLRGEKNNYGYTF 1.0 \n",
+ "\n",
+ "[14106 rows x 508 columns]"
+ ],
+ "text/html": [
+ "\n",
+ "
\n",
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " TRAV_TCRAV12-1 | \n",
+ " TRAV_TCRAV17 | \n",
+ " TRAV_TCRAV19 | \n",
+ " TRAV_TCRAV21 | \n",
+ " TRAV_TCRAV23/DV6 | \n",
+ " TRAV_TCRAV3 | \n",
+ " TRAV_TCRAV38-1 | \n",
+ " TRAV_TCRAV38-2/DV8 | \n",
+ " TRAV_TCRAV41 | \n",
+ " TRAV_TRAV-2 | \n",
+ " ... | \n",
+ " TRBJ_TRBJ2-5 | \n",
+ " TRBJ_TRBJ2-6 | \n",
+ " TRBJ_TRBJ2-7 | \n",
+ " TRBJ_TRBJ2-7 | \n",
+ " TRBJ_TRBJ20-1 | \n",
+ " TRBJ_TRBJ24-1 | \n",
+ " TRBJ_TRBJ38-2/DV8 | \n",
+ " TRBJ_TRBJ5-1 | \n",
+ " TRBJ_TRBJ5-6 | \n",
+ " TRBJ_nan | \n",
+ "
\n",
+ " \n",
+ " cdr3 | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ " | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " CAGGADRLTFCASSPAGNTLYF | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " ... | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " CAASGGSNYNVLYFCAWSLWGGPSAETLYF | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " ... | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " CAASYNYAQGLTFCASRDWGGRQDTQYF | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " ... | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " CAAQTGNYKYVFCASGDAGTGQDTQYF | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " ... | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " CAASLTGGYKVVFCAWRTDNQDTQYF | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " ... | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ "
\n",
+ " \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ "
\n",
+ " \n",
+ " CAYRSGEYGNKLVFCASSMAGSSYEQYF | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " ... | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ "
\n",
+ " \n",
+ " CAYRSFNNNDMRFCASRSRGGHSPLHF | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " ... | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ "
\n",
+ " \n",
+ " CATDNDMRFCASSFGPDEQYF | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " ... | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ "
\n",
+ " \n",
+ " CAVLNNARLMFCASSVDRVADTQYF | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " ... | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ "
\n",
+ " \n",
+ " CAMTSFQKLVFCASSLRGEKNNYGYTF | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " ... | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
14106 rows × 508 columns
\n",
+ "
\n",
+ "
\n",
+ "
\n"
+ ],
+ "application/vnd.google.colaboratory.intrinsic+json": {
+ "type": "dataframe",
+ "variable_name": "df_encoded_TCR"
+ }
+ },
+ "metadata": {},
+ "execution_count": 5
+ }
+ ],
+ "source": [
+ "##################### encode additional information for the TCRs #####################\n",
+ "from sklearn.preprocessing import OneHotEncoder\n",
+ "\n",
+ "# One hot encoding of categorical variables\n",
+ "columns_to_encode = ['TRAV','TRAJ','TRBV','TRBJ']\n",
+ "df = combined_df.loc[:,columns_to_encode]\n",
+ "\n",
+ "one_hot_encoder = OneHotEncoder(drop='first', sparse_output=False)\n",
+ "# Fit and transform the data\n",
+ "one_hot_encoded = one_hot_encoder.fit_transform(df)\n",
+ "feature_names = one_hot_encoder.get_feature_names_out(df.columns)\n",
+ "df_encoded = pd.DataFrame(one_hot_encoded, columns=feature_names)\n",
+ "\n",
+ "df_encoded.index = combined_df[\"cdr3\"]\n",
+ "df_encoded_TCR = df_encoded[~df_encoded.index.duplicated(keep=\"first\")]\n",
+ "df_encoded_TCR"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {
+ "id": "vUdehGAmrGra"
+ },
+ "outputs": [],
+ "source": [
+ "###################################### enter parameters ######################################\n",
+ "# Setting_new_epitope_new_TCR_LOOCV\n",
+ "\n",
+ "setting = \"Setting_new_epitope_new_TCR_FINAL_MODEL\"\n",
+ "\n",
+ "# \"ESM3 + VJ genes\" \"all features\" \"ESMonly\" \"withoutESM\"\n",
+ "features_name = \"ESM3 + VJ genes\"\n",
+ "\n",
+ "# MHCI MHCII all\n",
+ "MHC_class = \"all\"\n",
+ "\n",
+ "species = \"all\" # HomoSapiens all\n",
+ "tcr_features = \"cdr3\"\n",
+ "repetition = 5\n",
+ "algorithm = \"sklearn_logit\"\n",
+ "result_folder = \"MixTCRpred/output/\"+setting+\"/\"\n",
+ "nfolds = 5 # here it is only for Gridsearch\n",
+ "n_jobs = -1\n",
+ "\n",
+ "os.chdir(path)\n",
+ "os.makedirs(result_folder,exist_ok=True)\n",
+ "os.chdir(result_folder)\n",
+ "\n",
+ "if species != \"all\":\n",
+ " combined_df = combined_df.loc[combined_df[\"species\"]==species,:]\n",
+ "\n",
+ "if MHC_class != \"all\":\n",
+ " combined_df = combined_df.loc[combined_df[\"MHC_class\"]==MHC_class, : ]\n",
+ "\n",
+ "\n",
+ "############################################ percentile background peptides ############################################\n",
+ "p_test_mean_background = pd.read_csv(\"selected_background_peptides_embeddings_\"+features_name+\"_\"+model_name+'_predicted_prob.csv', index_col=0)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "combined_df.MHC_class.value_counts()"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 198
+ },
+ "id": "C9NE89S-N92J",
+ "outputId": "e4ba3e64-5038-4e1d-b523-197f2ed57361"
+ },
+ "execution_count": 7,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ "MHC_class\n",
+ "MHCI 13248\n",
+ "MHCII 4428\n",
+ "Name: count, dtype: int64"
+ ],
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " count | \n",
+ "
\n",
+ " \n",
+ " MHC_class | \n",
+ " | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " MHCI | \n",
+ " 13248 | \n",
+ "
\n",
+ " \n",
+ " MHCII | \n",
+ " 4428 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ]
+ },
+ "metadata": {},
+ "execution_count": 7
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# load model\n",
+ "import pickle\n",
+ "pickle_off = open(\"model_\"+features_name+\".pickle\",\"rb\")\n",
+ "model = pickle.load(pickle_off)\n"
+ ],
+ "metadata": {
+ "id": "JSI2yTCzryxf"
+ },
+ "execution_count": 8,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "####################################### INPUT: list of epitope sequences #######################################\n",
+ "\n",
+ "epitopes = [\"DIYKGMGPLLATVFKSV\",\"GMGPLLATVFKSV\"]\n",
+ "\n",
+ "MHC = \"H2-IAb\" # for now, these values do not matter since the model does not incorporate them !\n",
+ "MHC_class=\"MHCII\" # for now, these values do not matter since the model does not incorporate them !\n",
+ "species = \"MusMusculus\"# for now, these values do not matter since the model does not incorporate them !\n",
+ "\n",
+ "tcr = np.unique(combined_df[tcr_features])\n",
+ "selected_tcrs = random.sample(list(tcr), min(10000, len(tcr)))\n",
+ "\n",
+ "mat_test_tab = pd.DataFrame()\n",
+ "for epitope in epitopes:\n",
+ " print(epitope)\n",
+ " df = pd.DataFrame( {tcr_features : selected_tcrs , \"epitope\": epitope} )\n",
+ " mat_test_tab = pd.concat( [ mat_test_tab , df ] ).reset_index(drop=True)\n",
+ "mat_test_tab"
+ ],
+ "metadata": {
+ "id": "MGST9ARllAFr",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 480
+ },
+ "outputId": "56f8624f-92c0-4a70-ba24-9420ea3317fd"
+ },
+ "execution_count": 10,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "DIYKGMGPLLATVFKSV\n",
+ "GMGPLLATVFKSV\n"
+ ]
+ },
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ " cdr3 epitope\n",
+ "0 CAMRNNVGDNSKLIWCASGDAGWSNQDTQYF DIYKGMGPLLATVFKSV\n",
+ "1 CAWRGGGGADGLTFCASSWDPTYNEQFF DIYKGMGPLLATVFKSV\n",
+ "2 CILREGFGNVLHCCASSMRSGSEQFF DIYKGMGPLLATVFKSV\n",
+ "3 CAASIGNNRIFFCAWSLQEDTQYF DIYKGMGPLLATVFKSV\n",
+ "4 CAVSSNTGKLIFCASSASRVGEDTQYF DIYKGMGPLLATVFKSV\n",
+ "... ... ...\n",
+ "19995 CAVINMGYKLTFCASEDWGGAHAEQFF GMGPLLATVFKSV\n",
+ "19996 CALSGYTEGADRLTFCASSERNSGNTLYF GMGPLLATVFKSV\n",
+ "19997 CVLSANNNAGAKLTFCASSDAAREGQNTLYF GMGPLLATVFKSV\n",
+ "19998 CAVRDQAGTALIFCASSFGPVEQYF GMGPLLATVFKSV\n",
+ "19999 CALSGFTDKLIFCASSLRDGYTDTQYF GMGPLLATVFKSV\n",
+ "\n",
+ "[20000 rows x 2 columns]"
+ ],
+ "text/html": [
+ "\n",
+ " \n",
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " cdr3 | \n",
+ " epitope | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " CAMRNNVGDNSKLIWCASGDAGWSNQDTQYF | \n",
+ " DIYKGMGPLLATVFKSV | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " CAWRGGGGADGLTFCASSWDPTYNEQFF | \n",
+ " DIYKGMGPLLATVFKSV | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " CILREGFGNVLHCCASSMRSGSEQFF | \n",
+ " DIYKGMGPLLATVFKSV | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " CAASIGNNRIFFCAWSLQEDTQYF | \n",
+ " DIYKGMGPLLATVFKSV | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " CAVSSNTGKLIFCASSASRVGEDTQYF | \n",
+ " DIYKGMGPLLATVFKSV | \n",
+ "
\n",
+ " \n",
+ " ... | \n",
+ " ... | \n",
+ " ... | \n",
+ "
\n",
+ " \n",
+ " 19995 | \n",
+ " CAVINMGYKLTFCASEDWGGAHAEQFF | \n",
+ " GMGPLLATVFKSV | \n",
+ "
\n",
+ " \n",
+ " 19996 | \n",
+ " CALSGYTEGADRLTFCASSERNSGNTLYF | \n",
+ " GMGPLLATVFKSV | \n",
+ "
\n",
+ " \n",
+ " 19997 | \n",
+ " CVLSANNNAGAKLTFCASSDAAREGQNTLYF | \n",
+ " GMGPLLATVFKSV | \n",
+ "
\n",
+ " \n",
+ " 19998 | \n",
+ " CAVRDQAGTALIFCASSFGPVEQYF | \n",
+ " GMGPLLATVFKSV | \n",
+ "
\n",
+ " \n",
+ " 19999 | \n",
+ " CALSGFTDKLIFCASSLRDGYTDTQYF | \n",
+ " GMGPLLATVFKSV | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
20000 rows × 2 columns
\n",
+ "
\n",
+ "
\n",
+ "
\n"
+ ],
+ "application/vnd.google.colaboratory.intrinsic+json": {
+ "type": "dataframe",
+ "variable_name": "mat_test_tab",
+ "summary": "{\n \"name\": \"mat_test_tab\",\n \"rows\": 20000,\n \"fields\": [\n {\n \"column\": \"cdr3\",\n \"properties\": {\n \"dtype\": \"string\",\n \"num_unique_values\": 10000,\n \"samples\": [\n \"CAASNNRIFFCAWTGGIDEQYF\",\n \"CVVNDPSGNTPLVFCSARDERALNTGELFF\",\n \"CGNTGNQFYFCASSQGVGYTF\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"epitope\",\n \"properties\": {\n \"dtype\": \"category\",\n \"num_unique_values\": 2,\n \"samples\": [\n \"GMGPLLATVFKSV\",\n \"DIYKGMGPLLATVFKSV\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}"
+ }
+ },
+ "metadata": {},
+ "execution_count": 10
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "\n",
+ "################################# encoding epitope #################################\n",
+ "\n",
+ "epitope_embeddings = pd.DataFrame()\n",
+ "for epitope in epitopes:\n",
+ " print(epitope)\n",
+ " results = run_ESM3(epitope,ESM3_model)\n",
+ "\n",
+ " embed_total = pd.concat([results[0]], axis=1).transpose()\n",
+ " sequence_left = [results[1]]\n",
+ " embed_total.index = sequence_left\n",
+ " # Convert column names to strings before adding \"ESM3_\"\n",
+ " embed_total.columns = \"ESM3_\" + embed_total.columns.astype(str)\n",
+ " epitope_embeddings = pd.concat( [ epitope_embeddings , embed_total.copy() ] )\n",
+ "epitope_embeddings"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 197
+ },
+ "id": "SkvF4abSQmhO",
+ "outputId": "aac0c0f6-bc7b-4c53-9082-6a1b96c17b38"
+ },
+ "execution_count": 11,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "DIYKGMGPLLATVFKSV\n",
+ "GMGPLLATVFKSV\n"
+ ]
+ },
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ " ESM3_0 ESM3_1 ESM3_2 ESM3_3 ESM3_4 ESM3_5 \\\n",
+ "DIYKGMGPLLATVFKSV 0.023491 -0.103977 -0.012792 0.017144 0.038514 -0.093617 \n",
+ "GMGPLLATVFKSV 0.083255 -0.080921 -0.009170 0.004828 0.053290 -0.087519 \n",
+ "\n",
+ " ESM3_6 ESM3_7 ESM3_8 ESM3_9 ... ESM3_1526 \\\n",
+ "DIYKGMGPLLATVFKSV -0.049821 -0.009365 0.126688 -0.118027 ... 0.127035 \n",
+ "GMGPLLATVFKSV -0.056259 -0.005152 0.104326 -0.114268 ... 0.150475 \n",
+ "\n",
+ " ESM3_1527 ESM3_1528 ESM3_1529 ESM3_1530 ESM3_1531 \\\n",
+ "DIYKGMGPLLATVFKSV 0.001517 0.115759 0.041184 0.046807 -0.103359 \n",
+ "GMGPLLATVFKSV 0.080277 0.104821 -0.016300 0.032343 -0.084724 \n",
+ "\n",
+ " ESM3_1532 ESM3_1533 ESM3_1534 ESM3_1535 \n",
+ "DIYKGMGPLLATVFKSV 0.081879 0.041203 -0.049147 -0.609815 \n",
+ "GMGPLLATVFKSV 0.076475 0.116125 -0.067687 -0.679018 \n",
+ "\n",
+ "[2 rows x 1536 columns]"
+ ],
+ "text/html": [
+ "\n",
+ " \n",
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " ESM3_0 | \n",
+ " ESM3_1 | \n",
+ " ESM3_2 | \n",
+ " ESM3_3 | \n",
+ " ESM3_4 | \n",
+ " ESM3_5 | \n",
+ " ESM3_6 | \n",
+ " ESM3_7 | \n",
+ " ESM3_8 | \n",
+ " ESM3_9 | \n",
+ " ... | \n",
+ " ESM3_1526 | \n",
+ " ESM3_1527 | \n",
+ " ESM3_1528 | \n",
+ " ESM3_1529 | \n",
+ " ESM3_1530 | \n",
+ " ESM3_1531 | \n",
+ " ESM3_1532 | \n",
+ " ESM3_1533 | \n",
+ " ESM3_1534 | \n",
+ " ESM3_1535 | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " DIYKGMGPLLATVFKSV | \n",
+ " 0.023491 | \n",
+ " -0.103977 | \n",
+ " -0.012792 | \n",
+ " 0.017144 | \n",
+ " 0.038514 | \n",
+ " -0.093617 | \n",
+ " -0.049821 | \n",
+ " -0.009365 | \n",
+ " 0.126688 | \n",
+ " -0.118027 | \n",
+ " ... | \n",
+ " 0.127035 | \n",
+ " 0.001517 | \n",
+ " 0.115759 | \n",
+ " 0.041184 | \n",
+ " 0.046807 | \n",
+ " -0.103359 | \n",
+ " 0.081879 | \n",
+ " 0.041203 | \n",
+ " -0.049147 | \n",
+ " -0.609815 | \n",
+ "
\n",
+ " \n",
+ " GMGPLLATVFKSV | \n",
+ " 0.083255 | \n",
+ " -0.080921 | \n",
+ " -0.009170 | \n",
+ " 0.004828 | \n",
+ " 0.053290 | \n",
+ " -0.087519 | \n",
+ " -0.056259 | \n",
+ " -0.005152 | \n",
+ " 0.104326 | \n",
+ " -0.114268 | \n",
+ " ... | \n",
+ " 0.150475 | \n",
+ " 0.080277 | \n",
+ " 0.104821 | \n",
+ " -0.016300 | \n",
+ " 0.032343 | \n",
+ " -0.084724 | \n",
+ " 0.076475 | \n",
+ " 0.116125 | \n",
+ " -0.067687 | \n",
+ " -0.679018 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
2 rows × 1536 columns
\n",
+ "
\n",
+ "
\n",
+ "
\n"
+ ],
+ "application/vnd.google.colaboratory.intrinsic+json": {
+ "type": "dataframe",
+ "variable_name": "epitope_embeddings"
+ }
+ },
+ "metadata": {},
+ "execution_count": 11
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "#################### encode additional information for the epitopes ####################\n",
+ "from sklearn.preprocessing import OneHotEncoder\n",
+ "\n",
+ "# One hot encoding of categorical variables\n",
+ "columns_to_encode = ['MHC','MHC_class','species']\n",
+ "df = combined_df.loc[:,columns_to_encode]\n",
+ "\n",
+ "one_hot_encoder = OneHotEncoder(drop='first', sparse_output=False)\n",
+ "# Fit and transform the data\n",
+ "one_hot_encoded = one_hot_encoder.fit_transform(df)\n",
+ "feature_names = one_hot_encoder.get_feature_names_out(df.columns)\n",
+ "df_encoded = pd.DataFrame(one_hot_encoded, columns=feature_names)\n",
+ "\n",
+ "df_encoded.index = combined_df[\"epitope\"]\n",
+ "df_encoded_reference = df_encoded[~df_encoded.index.duplicated(keep=\"first\")]\n",
+ "\n",
+ "# Create a new row with all zeros and index name as epitope\n",
+ "df_encoded_epitope = pd.DataFrame()\n",
+ "for epitope in epitopes:\n",
+ " print(epitope)\n",
+ " new_row = pd.DataFrame(index=[epitope],columns=df_encoded_reference.columns)\n",
+ " df_encoded_reference_temp = pd.concat([new_row, df_encoded_reference])\n",
+ "\n",
+ "\n",
+ " if \"MHC_\" + MHC not in df_encoded_reference_temp.columns:\n",
+ " cols_to_zero = [col for col in df_encoded_reference_temp.columns if \"MHC_\" in col]\n",
+ " df_encoded_reference_temp[cols_to_zero] = 0\n",
+ " else:\n",
+ " df_encoded_reference_temp.loc[epitope,[\"MHC_\"+MHC]] = 1\n",
+ "\n",
+ " if \"MHC_class_\" + MHC_class not in df_encoded_reference_temp.columns:\n",
+ " cols_to_zero = [col for col in df_encoded_reference_temp.columns if \"MHC_class_\" in col]\n",
+ " df_encoded_reference_temp[cols_to_zero] = 0\n",
+ " else:\n",
+ " df_encoded_reference_temp.loc[epitope,[\"MHC_class_\"+MHC_class]] = 1\n",
+ "\n",
+ " if \"species_\" + species not in df_encoded_reference_temp.columns:\n",
+ " cols_to_zero = [col for col in df_encoded_reference_temp.columns if \"species_\" in col]\n",
+ " df_encoded_reference_temp[cols_to_zero] = 0\n",
+ " else:\n",
+ " df_encoded_reference_temp.loc[epitope,[\"species_\"+species]] = 1\n",
+ "\n",
+ "\n",
+ " df_encoded_reference_temp = df_encoded_reference_temp.fillna(0)\n",
+ " df_encoded_epitope = pd.concat( [ df_encoded_epitope , df_encoded_reference_temp.iloc[[0]] ] )\n",
+ "\n",
+ "df_encoded_epitope\n",
+ "\n",
+ ""
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 304
+ },
+ "id": "89a8vzOOQ1f8",
+ "outputId": "5989455f-d55b-47b1-f438-7601f56bdd7f"
+ },
+ "execution_count": 12,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "DIYKGMGPLLATVFKSV\n",
+ "GMGPLLATVFKSV\n"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ ":22: FutureWarning: The behavior of DataFrame concatenation with empty or all-NA entries is deprecated. In a future version, this will no longer exclude empty or all-NA columns when determining the result dtypes. To retain the old behavior, exclude the relevant entries before the concat operation.\n",
+ " df_encoded_reference_temp = pd.concat([new_row, df_encoded_reference])\n",
+ ":22: FutureWarning: The behavior of DataFrame concatenation with empty or all-NA entries is deprecated. In a future version, this will no longer exclude empty or all-NA columns when determining the result dtypes. To retain the old behavior, exclude the relevant entries before the concat operation.\n",
+ " df_encoded_reference_temp = pd.concat([new_row, df_encoded_reference])\n"
+ ]
+ },
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ " MHC_H2-Db MHC_H2-IAb MHC_H2-IEk MHC_H2-Kb MHC_H2-Kd \\\n",
+ "DIYKGMGPLLATVFKSV 0.0 1.0 0.0 0.0 0.0 \n",
+ "GMGPLLATVFKSV 0.0 1.0 0.0 0.0 0.0 \n",
+ "\n",
+ " MHC_H2-Ld MHC_HLA-A*02:01 MHC_HLA-A*08:01 \\\n",
+ "DIYKGMGPLLATVFKSV 0.0 0.0 0.0 \n",
+ "GMGPLLATVFKSV 0.0 0.0 0.0 \n",
+ "\n",
+ " MHC_HLA-A*11:01 MHC_HLA-A*24:02 ... \\\n",
+ "DIYKGMGPLLATVFKSV 0.0 0.0 ... \n",
+ "GMGPLLATVFKSV 0.0 0.0 ... \n",
+ "\n",
+ " MHC_HLA-DQA1:02/DQB1*06:02 MHC_HLA-DRA:01 \\\n",
+ "DIYKGMGPLLATVFKSV 0.0 0.0 \n",
+ "GMGPLLATVFKSV 0.0 0.0 \n",
+ "\n",
+ " MHC_HLA-DRA:01/DRB1:01 MHC_HLA-DRB1*04:01 \\\n",
+ "DIYKGMGPLLATVFKSV 0.0 0.0 \n",
+ "GMGPLLATVFKSV 0.0 0.0 \n",
+ "\n",
+ " MHC_HLA-DRB1*04:05 MHC_HLA-DRB1*07:01 MHC_HLA-DRB1*11:01 \\\n",
+ "DIYKGMGPLLATVFKSV 0.0 0.0 0.0 \n",
+ "GMGPLLATVFKSV 0.0 0.0 0.0 \n",
+ "\n",
+ " MHC_HLA-DRB1:01 MHC_class_MHCII species_MusMusculus \n",
+ "DIYKGMGPLLATVFKSV 0.0 1.0 1.0 \n",
+ "GMGPLLATVFKSV 0.0 1.0 1.0 \n",
+ "\n",
+ "[2 rows x 37 columns]"
+ ],
+ "text/html": [
+ "\n",
+ " \n",
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " MHC_H2-Db | \n",
+ " MHC_H2-IAb | \n",
+ " MHC_H2-IEk | \n",
+ " MHC_H2-Kb | \n",
+ " MHC_H2-Kd | \n",
+ " MHC_H2-Ld | \n",
+ " MHC_HLA-A*02:01 | \n",
+ " MHC_HLA-A*08:01 | \n",
+ " MHC_HLA-A*11:01 | \n",
+ " MHC_HLA-A*24:02 | \n",
+ " ... | \n",
+ " MHC_HLA-DQA1:02/DQB1*06:02 | \n",
+ " MHC_HLA-DRA:01 | \n",
+ " MHC_HLA-DRA:01/DRB1:01 | \n",
+ " MHC_HLA-DRB1*04:01 | \n",
+ " MHC_HLA-DRB1*04:05 | \n",
+ " MHC_HLA-DRB1*07:01 | \n",
+ " MHC_HLA-DRB1*11:01 | \n",
+ " MHC_HLA-DRB1:01 | \n",
+ " MHC_class_MHCII | \n",
+ " species_MusMusculus | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " DIYKGMGPLLATVFKSV | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " ... | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ "
\n",
+ " \n",
+ " GMGPLLATVFKSV | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " ... | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 0.0 | \n",
+ " 1.0 | \n",
+ " 1.0 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
2 rows × 37 columns
\n",
+ "
\n",
+ "
\n",
+ "
\n"
+ ],
+ "application/vnd.google.colaboratory.intrinsic+json": {
+ "type": "dataframe",
+ "variable_name": "df_encoded_epitope"
+ }
+ },
+ "metadata": {},
+ "execution_count": 12
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "############################################### run prediction on epitope ###############################################\n",
+ "\n",
+ "p_test = run_prediction(mat_test_tab,epitope_embeddings,cdr3_embeddings)\n",
+ ""
+ ],
+ "metadata": {
+ "id": "As-zxbIKkmTO"
+ },
+ "execution_count": 13,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "############################################# average predicted probability #############################################\n",
+ "\n",
+ "p_test = p_test.loc[:,[\"epitope\",\"predicted_prob\"]]\n",
+ "p_test_mean = p_test.groupby('epitope').agg('mean')\n",
+ "p_test_mean"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 164
+ },
+ "id": "osqix4p2sTtq",
+ "outputId": "18ffe67c-7470-403e-c30a-bc53ed3bfd14"
+ },
+ "execution_count": 14,
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ " predicted_prob\n",
+ "epitope \n",
+ "DIYKGMGPLLATVFKSV 0.880342\n",
+ "GMGPLLATVFKSV 0.604100"
+ ],
+ "text/html": [
+ "\n",
+ " \n",
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " predicted_prob | \n",
+ "
\n",
+ " \n",
+ " epitope | \n",
+ " | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " DIYKGMGPLLATVFKSV | \n",
+ " 0.880342 | \n",
+ "
\n",
+ " \n",
+ " GMGPLLATVFKSV | \n",
+ " 0.604100 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "
\n"
+ ],
+ "application/vnd.google.colaboratory.intrinsic+json": {
+ "type": "dataframe",
+ "variable_name": "p_test_mean",
+ "summary": "{\n \"name\": \"p_test_mean\",\n \"rows\": 2,\n \"fields\": [\n {\n \"column\": \"epitope\",\n \"properties\": {\n \"dtype\": \"string\",\n \"num_unique_values\": 2,\n \"samples\": [\n \"GMGPLLATVFKSV\",\n \"DIYKGMGPLLATVFKSV\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"predicted_prob\",\n \"properties\": {\n \"dtype\": \"number\",\n \"std\": 0.19533245581349337,\n \"min\": 0.6041002243846824,\n \"max\": 0.880342032567768,\n \"num_unique_values\": 2,\n \"samples\": [\n 0.6041002243846824,\n 0.880342032567768\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}"
+ }
+ },
+ "metadata": {},
+ "execution_count": 14
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "################################################ compute Percentile Rank ################################################\n",
+ "from scipy import stats\n",
+ "percentile_rank_all = pd.DataFrame()\n",
+ "for epitope in p_test_mean.index:\n",
+ "\n",
+ " # Calculate the mean predicted probability for the test set\n",
+ " mean_predicted_prob = p_test_mean.loc[epitope,\"predicted_prob\"]\n",
+ "\n",
+ " # Calculate the percentile rank\n",
+ " percentile_rank = round(100 - stats.percentileofscore(p_test_mean_background[\"predicted_prob\"], mean_predicted_prob), 2 )\n",
+ "\n",
+ " df = pd.DataFrame([epitope, percentile_rank]).transpose()\n",
+ " percentile_rank_all = pd.concat( [ percentile_rank_all , df ] )\n",
+ "\n",
+ " # Print the percentile rank\n",
+ " print(f\"Percentile Rank: {percentile_rank}\")\n",
+ "\n",
+ "# Percentile Rank from 0 to 100. the closer to 0 the stronger the predicted TCR recognition.\n",
+ "percentile_rank_all.columns = [\"epitope\",\"Percentile Rank\"]\n",
+ "percentile_rank_all"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 181
+ },
+ "id": "DS5pQUWV26ll",
+ "outputId": "e553d1ba-1c01-4387-b57b-1f018d145890"
+ },
+ "execution_count": 22,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Percentile Rank: 1.25\n",
+ "Percentile Rank: 16.15\n"
+ ]
+ },
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": [
+ " epitope Percentile Rank\n",
+ "0 DIYKGMGPLLATVFKSV 1.25\n",
+ "0 GMGPLLATVFKSV 16.15"
+ ],
+ "text/html": [
+ "\n",
+ " \n",
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " epitope | \n",
+ " Percentile Rank | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " DIYKGMGPLLATVFKSV | \n",
+ " 1.25 | \n",
+ "
\n",
+ " \n",
+ " 0 | \n",
+ " GMGPLLATVFKSV | \n",
+ " 16.15 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
\n",
+ "
\n",
+ "
\n"
+ ],
+ "application/vnd.google.colaboratory.intrinsic+json": {
+ "type": "dataframe",
+ "variable_name": "percentile_rank_all",
+ "summary": "{\n \"name\": \"percentile_rank_all\",\n \"rows\": 2,\n \"fields\": [\n {\n \"column\": \"epitope\",\n \"properties\": {\n \"dtype\": \"string\",\n \"num_unique_values\": 2,\n \"samples\": [\n \"GMGPLLATVFKSV\",\n \"DIYKGMGPLLATVFKSV\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"Percentile Rank\",\n \"properties\": {\n \"dtype\": \"date\",\n \"min\": 1.25,\n \"max\": 16.15,\n \"num_unique_values\": 2,\n \"samples\": [\n 16.15,\n 1.25\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}"
+ }
+ },
+ "metadata": {},
+ "execution_count": 22
+ }
+ ]
+ }
+ ],
+ "metadata": {
+ "colab": {
+ "provenance": [],
+ "machine_shape": "hm"
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
\ No newline at end of file