Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Demo basic use of deepali.spatial for learning image registration #91

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 168 additions & 4 deletions docs/tutorials/pairwise-registration-intro.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,12 @@
"outputs": [],
"source": [
"try:\n",
" from deepali.utils.cli import cuda_visible_devices\n",
" import deepali.core.functional as U\n",
"except ImportError:\n",
" if not os.getenv(\"COLAB_RELEASE_TAG\"):\n",
" raise\n",
" !git clone https://github.com/BioMedIA/deepali.git && pip install ./deepali\n",
" from deepali.utils.cli import cuda_visible_devices"
" import deepali.core.functional as U"
]
},
{
Expand All @@ -65,10 +65,10 @@
"metadata": {},
"outputs": [],
"source": [
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"4\"\n",
"\n",
"# Use first device specified in CUDA_VISIBLE_DEVICES if CUDA is available\n",
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() and cuda_visible_devices() else \"cpu\")"
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")"
]
},
{
Expand Down Expand Up @@ -934,6 +934,170 @@
"_ = invertible_registration_figure(target_cshape, source_circle, circle_to_c_transform)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Learning registration\n",
"\n",
"We can utilize the same `deepali.spatial` modules that we used for traditional image registration to realize a learning based approach. Previously, we optimized the parameters of the chosen spatial transform directly and individually for each given pair of images to register. In learned image registration, we instead employ a regression model which infers the parameters of the spatial transform from a given input, i.e., the fixed target and moving source images in case of pairwise image registration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from deepali.networks.resnet import ResNet\n",
"\n",
"\n",
"model = ResNet(\n",
" spatial_dims=2,\n",
" in_channels=2,\n",
" stride=(1, 2, 2),\n",
" num_blocks=(3, 4, 6),\n",
" num_channels=(8, 16, 32),\n",
" num_layers=2,\n",
" kernel_size=3,\n",
" expansion=1,\n",
" norm=None,\n",
" acti=\"relu\",\n",
")\n",
"model.add_module(\"flatten\", torch.nn.Flatten())\n",
"\n",
"in_features = cast(Tensor, model.forward(torch.rand((1, 2) + target.shape[1:]))).numel()\n",
"out_params = 1 # rotation angle\n",
"\n",
"model.add_module(\"params\", torch.nn.Linear(in_features, out_params))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from torch.utils.data import Dataset\n",
"\n",
"\n",
"class ImagePairs(Dataset):\n",
" r\"\"\"Dataset producing an endless stream of randomly paired images from a map-style dataset.\"\"\"\n",
"\n",
" def __init__(self, mnist: MNIST, digit: int, n: int = 0) -> None:\n",
" images = []\n",
" for image, label in mnist:\n",
" if label != digit:\n",
" continue\n",
" images.append(image)\n",
" if n > 0 and len(images) == n:\n",
" break\n",
" self.images = images\n",
"\n",
" def __len__(self) -> int:\n",
" return len(self.images) ** 2\n",
"\n",
" def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:\n",
" n = len(self.images)\n",
" i = index // n\n",
" j = index % n\n",
" target = self.images[i]\n",
" source = self.images[j]\n",
" return target, source\n",
"\n",
"\n",
"dataset = ImagePairs(mnist, 9, 100)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from torch.utils.data import DataLoader, random_split\n",
"\n",
"\n",
"epochs = 100\n",
"steps = -1\n",
"batch_size = 10\n",
"\n",
"\n",
"generator = torch.Generator().manual_seed(42)\n",
"train_dataset, val_dataset = random_split(dataset, [0.8, 0.2], generator=generator)\n",
"\n",
"train_dataloader = DataLoader(\n",
" train_dataset,\n",
" batch_size=batch_size,\n",
" drop_last=True,\n",
" shuffle=True,\n",
")\n",
"\n",
"val_dataloader = DataLoader(train_dataset, batch_size=1)\n",
"\n",
"transform = spatial.EulerRotation(grid, groups=batch_size, params=model)\n",
"transformer = spatial.ImageTransformer(transform).to(device)\n",
"optimizer = optim.Adam(model.parameters(), lr=1e-2)\n",
"\n",
"bar_format = \"{l_bar}{bar}{rate_fmt}{postfix}\"\n",
"\n",
"total_steps = 0\n",
"for epoch in range(epochs):\n",
" with tqdm(total=len(train_dataloader), bar_format=bar_format) as pbar:\n",
" pbar.set_description(f\"Epoch {epoch}\")\n",
" transform.train()\n",
" for step, (target_batch, source_batch) in enumerate(train_dataloader):\n",
" assert isinstance(target_batch, Tensor)\n",
" assert isinstance(source_batch, Tensor)\n",
" target_batch = target_batch.to(device=device, non_blocking=True)\n",
" source_batch = source_batch.to(device=device, non_blocking=True)\n",
" input = torch.cat([target_batch, source_batch], dim=1).div(255)\n",
" transformer.condition_(input)\n",
" warped_batch = transformer(source_batch)\n",
" loss = sim_loss(warped_batch, target_batch)\n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" optimizer.step()\n",
" total_steps += 1\n",
" pbar.set_postfix(loss=loss.item())\n",
" pbar.update()\n",
" if steps > 0 and step >= steps:\n",
" break\n",
" transform.eval()\n",
" with torch.inference_mode():\n",
" val_loss = torch.tensor(0.0, device=device)\n",
" for target_batch, source_batch in val_dataloader:\n",
" assert isinstance(target_batch, Tensor)\n",
" assert isinstance(source_batch, Tensor)\n",
" target_batch = target_batch.to(device=device, non_blocking=True)\n",
" source_batch = source_batch.to(device=device, non_blocking=True)\n",
" input = torch.cat([target_batch, source_batch], dim=1).div(255)\n",
" transformer.condition_(input)\n",
" warped_batch = transformer(source_batch)\n",
" loss = sim_loss(warped_batch, target_batch)\n",
" val_loss = val_loss.add_(loss)\n",
" val_loss = val_loss.div_(len(val_dataloader))\n",
" pbar.set_postfix(val_loss=val_loss.item())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"transform.eval()\n",
"with torch.inference_mode():\n",
" warped = transformer(source.to(device)).cpu()\n",
"\n",
"fig, axes = plt.subplots(1, 3, figsize=(12, 4), tight_layout=True)\n",
"\n",
"imshow(target, \"target\", ax=axes[0])\n",
"imshow(warped, \"warped\", ax=axes[1])\n",
"imshow(source, \"source\", ax=axes[2])"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down