Skip to content

Commit

Permalink
Created DLICV pipelien file for better import | arguments fixed and w…
Browse files Browse the repository at this point in the history
…orking with just the 3 basic(i, o, d)
  • Loading branch information
spirosmaggioros committed Dec 11, 2024
1 parent 0f456ce commit ed2f01a
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 42 deletions.
21 changes: 19 additions & 2 deletions DLICV/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,26 @@ def main() -> None:
)

args = parser.parse_args()
args.f = [0]

run_pipeline(parser, args)
run_pipeline(
args.in_dir,
args.out_dir,
args.device,
args.clear_cache,
args.d,
args.part_id,
args.num_parts,
args.step_size,
args.disable_tta,
args.verbose,
args.disable_progress_bar,
args.chk,
args.save_probabilities,
args.continue_prediction,
args.npp,
args.nps,
args.prev_stage_predictions,
)


if __name__ == "__main__":
Expand Down
97 changes: 57 additions & 40 deletions DLICV/dlicv_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,70 @@
import argparse
import json
import os
import shutil
import sys
from pathlib import Path
from typing import Optional

import torch

from .utils import prepare_data_folder, rename_and_copy_files


def run_pipeline(parser: argparse.ArgumentParser, args: argparse.Namespace) -> None:
if args.clear_cache:
def run_pipeline(
in_dir: str,
out_dir: str,
device: str,
clear_cache: bool = False,
d: str = "901",
part_id: int = 0,
num_parts: int = 1,
step_size: float = 0.5,
disable_tta: bool = False,
verbose: bool = False,
disable_progress_bar: bool = False,
chk: bool = False,
save_probabilities: bool = False,
continue_prediction: bool = False,
npp: int = 2,
nps: int = 2,
prev_stage_predictions: Optional[str] = None,
) -> None:
f = [0]
if clear_cache:
shutil.rmtree(os.path.join(Path(__file__).parent, "nnunet_results"))
shutil.rmtree(os.path.join(Path(__file__).parent, ".cache"))
if not args.in_dir or not args.out_dir:
if not in_dir or not out_dir:
print("Cache cleared and missing either -i / -o. Exiting.")
sys.exit(0)

if not args.in_dir or not args.out_dir:
parser.error("The following arguments are required: -i, -o")
if not in_dir or not out_dir:
print("The following arguments are required: -i, -o")
sys.exit(0)

# data conversion
src_folder = args.in_dir # input folder
if not os.path.exists(args.out_dir): # create output folder if it does not exist
os.makedirs(args.out_dir)
src_folder = in_dir # input folder
if not os.path.exists(out_dir): # create output folder if it does not exist
os.makedirs(out_dir)

des_folder = os.path.join(args.out_dir, "renamed_image")
des_folder = os.path.join(out_dir, "renamed_image")

# check if -i argument is a folder, list (csv), or a single file (nii.gz)
if os.path.isdir(args.in_dir): # if args.i is a directory
src_folder = args.in_dir
if os.path.isdir(in_dir): # if args.i is a directory
src_folder = in_dir
prepare_data_folder(des_folder)
rename_dic, rename_back_dict = rename_and_copy_files(src_folder, des_folder)
datalist_file = os.path.join(des_folder, "renaming.json")
with open(datalist_file, "w", encoding="utf-8") as f:
json.dump(rename_dic, f, ensure_ascii=False, indent=4)
with open(datalist_file, "w", encoding="utf-8") as ff:
json.dump(rename_dic, ff, ensure_ascii=False, indent=4)
print(f"Renaming dic is saved to {datalist_file}")

model_folder = os.path.join(
Path(__file__).parent,
"nnunet_results",
"Dataset%s_Task%s_dlicv/nnUNetTrainer__nnUNetPlans__3d_fullres/"
% (args.d, args.d),
"Dataset%s_Task%s_dlicv/nnUNetTrainer__nnUNetPlans__3d_fullres/" % (d, d),
)

if args.clear_cache:
if clear_cache:
shutil.rmtree(os.path.join(Path(__file__).parent, "nnunet_results"))
shutil.rmtree(os.path.join(Path(__file__).parent, ".cache"))

Expand All @@ -62,25 +81,25 @@ def run_pipeline(parser: argparse.ArgumentParser, args: argparse.Namespace) -> N
else:
print("Loading the model...")

prepare_data_folder(args.out_dir)
prepare_data_folder(out_dir)

# Check for invalid arguments - advise users to see nnUNetv2 documentation
assert args.part_id < args.num_parts, "See nnUNetv2_predict -h."
assert part_id < num_parts, "See nnUNetv2_predict -h."

assert args.device in [
assert device in [
"cpu",
"cuda",
"mps",
], f"-device must be either cpu, mps or cuda. Other devices are not tested/supported. Got: {args.device}."
], f"-device must be either cpu, mps or cuda. Other devices are not tested/supported. Got: {device}."

if args.device == "cpu":
if device == "cpu":
import multiprocessing

torch.set_num_threads(
multiprocessing.cpu_count() // 2
) # use half of the threads (better for PC)
device = torch.device("cpu")
elif args.device == "cuda":
elif device == "cuda":
# multithreading in torch doesn't help nnU-Netv2 if run on GPU
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
Expand All @@ -99,36 +118,34 @@ def run_pipeline(parser: argparse.ArgumentParser, args: argparse.Namespace) -> N

# Initialize nnUnetPredictor
predictor = nnUNetPredictor(
tile_step_size=args.step_size,
tile_step_size=step_size,
use_gaussian=True,
use_mirroring=not args.disable_tta,
use_mirroring=not disable_tta,
perform_everything_on_device=True,
device=device,
verbose=args.verbose,
verbose_preprocessing=args.verbose,
allow_tqdm=not args.disable_progress_bar,
verbose=verbose,
verbose_preprocessing=verbose,
allow_tqdm=not disable_progress_bar,
)

# Retrieve the model and its weight
predictor.initialize_from_trained_model_folder(
model_folder, args.f, checkpoint_name=args.chk
)
predictor.initialize_from_trained_model_folder(model_folder, f, checkpoint_name=chk)

# Final prediction
predictor.predict_from_files(
des_folder,
args.out_dir,
save_probabilities=args.save_probabilities,
overwrite=not args.continue_prediction,
num_processes_preprocessing=args.npp,
num_processes_segmentation_export=args.nps,
folder_with_segs_from_prev_stage=args.prev_stage_predictions,
num_parts=args.num_parts,
part_id=args.part_id,
out_dir,
save_probabilities=save_probabilities,
overwrite=not continue_prediction,
num_processes_preprocessing=npp,
num_processes_segmentation_export=nps,
folder_with_segs_from_prev_stage=prev_stage_predictions,
num_parts=num_parts,
part_id=part_id,
)

# After prediction, convert the image name back to original
files_folder = args.out_dir
files_folder = out_dir

for filename in os.listdir(files_folder):
if filename.endswith(".nii.gz"):
Expand Down

0 comments on commit ed2f01a

Please sign in to comment.