Skip to content

Commit

Permalink
Binary search implemented for python methods
Browse files Browse the repository at this point in the history
  • Loading branch information
Jieran-S committed Feb 26, 2024
1 parent 521ad4a commit 3f7a160
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 20 deletions.
16 changes: 13 additions & 3 deletions method/STAGATE/method_STAGATE.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,22 @@
import numpy as np
import pandas as pd
import scanpy as sc
import sys
# Use tensorflow or pyG
import tensorflow.compat.v1 as tf
tf.compat.v1.disable_eager_execution()
# the location of R (used for the mclust clustering)
import scipy as sp

# import res-n_clust tuning function
import sys
from pathlib import Path

# Add the directory one level above this script's folder (i.e. `method/`, where search_res.py lives) to sys.path
method_dir = Path(__file__).resolve().parent.parent # Navigate two levels up
sys.path.append(str(method_dir))

from search_res import binary_search

def get_anndata(args):
import anndata as ad
X = sp.io.mmread(args.matrix)
Expand Down Expand Up @@ -152,8 +161,9 @@ def get_anndata(args):
adata = sg.mclust_R(adata, used_obsm='STAGATE', num_cluster=n_clusters, random_seed=seed)
label_df = adata.obs[["mclust"]]
elif config["method"] == "louvain":
sc.tl.louvain(adata, resolution=config["res"])
label_df = adata.obs[["louvain"]]
label_df = binary_search(adata, n_clust_target=n_clusters, method="louvain", seed = seed)
#sc.tl.louvain(adata, resolution=config["res"])
#label_df = adata.obs[["louvain"]]

## Write output
out_dir.mkdir(parents=True, exist_ok=True)
Expand Down
20 changes: 17 additions & 3 deletions method/SpaceFlow/method_spaceflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,11 +125,22 @@ def get_anndata(args):
random.seed(seed)

import scanpy as sc
import anndata as ad
from SpaceFlow import SpaceFlow
import pandas as pd
import warnings
import torch

# import res-n_clust tuning function
import sys
from pathlib import Path

# Add the directory one level above this script's folder (i.e. `method/`, where search_res.py lives) to sys.path
method_dir = Path(__file__).resolve().parent.parent # Navigate two levels up
sys.path.append(str(method_dir))

from search_res import binary_search

use_cuda = torch.cuda.is_available()
device = 1 if use_cuda else 0

Expand All @@ -155,12 +166,15 @@ def get_anndata(args):
nn = 15

# Raise a warning that clustering is based on resolution and not n_clusters
warnings.warn("The `n_clusters` parameter was not used; config['res'] used instead.")
# warnings.warn("The `n_clusters` parameter was not used; config['res'] used instead.")

# Segment the domains given the resolution
sf.segmentation(domain_label_save_filepath=label_file, n_neighbors=nn, resolution=res)
embedding_adata = ad.AnnData(sf.embedding)
sc.pp.neighbors(embedding_adata, n_neighbors=nn, use_rep="X")
label_df = binary_search(embedding_adata, n_clust_target=n_clusters, method="leiden", seed = seed)
# sf.segmentation(domain_label_save_filepath=label_file, n_neighbors=nn, resolution=res)

label_df = pd.DataFrame(sf.domains) # DataFrame with index (cell-id/barcode) and 1 column (label)
# label_df = pd.DataFrame(sf.domains) # DataFrame with index (cell-id/barcode) and 1 column (label)
embedding_df = pd.DataFrame(sf.embedding, index=adata.obs_names) # DataFrame with index (cell-id/barcode) and n columns

## Write output
Expand Down
24 changes: 19 additions & 5 deletions method/scanpy/method_scanpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,16 @@ def get_anndata(args):
import scanpy as sc
import json

# import res-n_clust tuning function
import sys
from pathlib import Path

# Add the directory one level above this script's folder (i.e. `method/`, where search_res.py lives) to sys.path
method_dir = Path(__file__).resolve().parent.parent # Navigate two levels up
sys.path.append(str(method_dir))

from search_res import binary_search

# get the json config
with open (args.config, "r") as c:
config = json.load(c)
Expand All @@ -160,7 +170,7 @@ def get_anndata(args):
# throw warning about not using the num_clusters as a parameter, because scanpy uses leiden or louvain and needs the resolution parameter, which is defined in the config.json. There exists a good function to perform an extensive search for the right resolution parameter to define the desired num_clusters, but we still have licensing issues. For more, see the SpaceHack2.0 GitHub issue #139
import warnings

warnings.warn("Scanpy uses leiden/louvain for clustering, which relies on the resolution parameter in the config file. The parameter num_clusters will be ignored.", UserWarning)
# warnings.warn("Scanpy uses leiden/louvain for clustering, which relies on the resolution parameter in the config file. The parameter num_clusters will be ignored.", UserWarning)

# scanpy starts here
if not args.dim_red:
Expand All @@ -173,14 +183,18 @@ def get_anndata(args):

# two options - leiden or louvain
if config['clustering'] == "louvain":
sc.tl.louvain(adata, resolution=config["resolution"], random_state=seed, key_added='louvain')
label_df = binary_search(adata, n_clust_target=n_clusters, method="louvain", seed = seed)
# sc.tl.louvain(adata, resolution=config["resolution"], random_state=seed, key_added='louvain')
elif config['clustering'] == "leiden":
sc.tl.leiden(adata, resolution=config["resolution"], random_state=seed)
label_df = binary_search(adata, n_clust_target=n_clusters, method="leiden", seed = seed)
# sc.tl.leiden(adata, resolution=config["resolution"], random_state=seed)
else:
print("No clustering method defined or your method is not available, performing leiden")
sc.tl.leiden(adata, resolution=config["resolution"], random_state=seed)
label_df = binary_search(adata, n_clust_target=n_clusters, method="leiden", seed = seed)

# sc.tl.leiden(adata, resolution=config["resolution"], random_state=seed)

label_df = adata.obs[["leiden"]]
# label_df = adata.obs[["leiden"]]


# embedding_df = None # optional, DataFrame with index (cell-id/barcode) and n columns
Expand Down
13 changes: 4 additions & 9 deletions method/binary_search.py → method/search_res.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def binary_search(
tolerance = 1e-3,
seed = 2023,
):
"""
"""
Uses binary search to find the resolution parameter that results in the target number of clusters.
Parameters
Expand Down Expand Up @@ -46,21 +46,16 @@ def binary_search(
Returns
------------
y: a pandas DataFrame with one column holding the clustering results from the best resolution, with barcodes as the index.
"""
"""
import scanpy as sc
import numpy as np
import warnings

y = None

def do_clustering(res):
match method:
case "louvain":
sc.tl.louvain(adata, resolution=res, random_state=seed)
y = adata.obs[["louvain"]].astype(int)
case "leiden":
sc.tl.leiden(adata, resolution=res, random_state=seed)
y = adata.obs[["leiden"]].astype(int)
getattr(sc.tl, method)(adata, resolution=res, random_state=seed)
y = adata.obs[[method]].astype(int)
n_clust = len(np.unique(y))
return y, n_clust

Expand Down

0 comments on commit 3f7a160

Please sign in to comment.