From 4a313be99cbf98b058a31695181ac5c59804325a Mon Sep 17 00:00:00 2001 From: Andreas Schuh Date: Wed, 19 Jul 2023 21:06:51 +0000 Subject: [PATCH] [wip] Demo basic use of deepali.spatial for learning image registration --- .../pairwise-registration-intro.ipynb | 172 +++++++++++++++++- 1 file changed, 168 insertions(+), 4 deletions(-) diff --git a/docs/tutorials/pairwise-registration-intro.ipynb b/docs/tutorials/pairwise-registration-intro.ipynb index 5a53861..717e2dd 100644 --- a/docs/tutorials/pairwise-registration-intro.ipynb +++ b/docs/tutorials/pairwise-registration-intro.ipynb @@ -43,12 +43,12 @@ "outputs": [], "source": [ "try:\n", - " from deepali.utils.cli import cuda_visible_devices\n", + " import deepali.core.functional as U\n", "except ImportError:\n", " if not os.getenv(\"COLAB_RELEASE_TAG\"):\n", " raise\n", " !git clone https://github.com/BioMedIA/deepali.git && pip install ./deepali\n", - " from deepali.utils.cli import cuda_visible_devices" + " import deepali.core.functional as U" ] }, { @@ -65,10 +65,10 @@ "metadata": {}, "outputs": [], "source": [ - "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n", + "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"4\"\n", "\n", "# Use first device specified in CUDA_VISIBLE_DEVICES if CUDA is available\n", - "device = torch.device(\"cuda:0\" if torch.cuda.is_available() and cuda_visible_devices() else \"cpu\")" + "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")" ] }, { @@ -934,6 +934,170 @@ "_ = invertible_registration_figure(target_cshape, source_circle, circle_to_c_transform)" ] }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Learning registration\n", + "\n", + "We can utilize the same `deepali.spatial` modules that we used for traditional image registration to realize a learning based approach. Previously, we optimized the parameters of the chosen spatial transform directly and individually for each given pair of images to register. In learned image registration, we instead employ a regression model which infers the parameters of the spatial transform from a given input, i.e., the fixed target and moving source images in case of pairwise image registration." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from deepali.networks.resnet import ResNet\n", + "\n", + "\n", + "model = ResNet(\n", + " spatial_dims=2,\n", + " in_channels=2,\n", + " stride=(1, 2, 2),\n", + " num_blocks=(3, 4, 6),\n", + " num_channels=(8, 16, 32),\n", + " num_layers=2,\n", + " kernel_size=3,\n", + " expansion=1,\n", + " norm=None,\n", + " acti=\"relu\",\n", + ")\n", + "model.add_module(\"flatten\", torch.nn.Flatten())\n", + "\n", + "in_features = cast(Tensor, model.forward(torch.rand((1, 2) + target.shape[1:]))).numel()\n", + "out_params = 1 # rotation angle\n", + "\n", + "model.add_module(\"params\", torch.nn.Linear(in_features, out_params))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from torch.utils.data import Dataset\n", + "\n", + "\n", + "class ImagePairs(Dataset):\n", + " r\"\"\"Dataset producing an endless stream of randomly paired images from a map-style dataset.\"\"\"\n", + "\n", + " def __init__(self, mnist: MNIST, digit: int, n: int = 0) -> None:\n", + " images = []\n", + " for image, label in mnist:\n", + " if label != digit:\n", + " continue\n", + " images.append(image)\n", + " if n > 0 and len(images) == n:\n", + " break\n", + " self.images = images\n", + "\n", + " def __len__(self) -> int:\n", + " return len(self.images) ** 2\n", + "\n", + " def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:\n", + " n = len(self.images)\n", + " i = index // n\n", + " j = index % n\n", + " target = self.images[i]\n", + " source = self.images[j]\n", + " return target, source\n", + "\n", + "\n", + "dataset = ImagePairs(mnist, 9, 100)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from torch.utils.data import DataLoader, random_split\n", + "\n", + "\n", + "epochs = 100\n", + "steps = -1\n", + "batch_size = 10\n", + "\n", + "\n", + "generator = torch.Generator().manual_seed(42)\n", + "train_dataset, val_dataset = random_split(dataset, [0.8, 0.2], generator=generator)\n", + "\n", + "train_dataloader = DataLoader(\n", + " train_dataset,\n", + " batch_size=batch_size,\n", + " drop_last=True,\n", + " shuffle=True,\n", + ")\n", + "\n", + "val_dataloader = DataLoader(train_dataset, batch_size=1)\n", + "\n", + "transform = spatial.EulerRotation(grid, groups=batch_size, params=model)\n", + "transformer = spatial.ImageTransformer(transform).to(device)\n", + "optimizer = optim.Adam(model.parameters(), lr=1e-2)\n", + "\n", + "bar_format = \"{l_bar}{bar}{rate_fmt}{postfix}\"\n", + "\n", + "total_steps = 0\n", + "for epoch in range(epochs):\n", + " with tqdm(total=len(train_dataloader), bar_format=bar_format) as pbar:\n", + " pbar.set_description(f\"Epoch {epoch}\")\n", + " transform.train()\n", + " for step, (target_batch, source_batch) in enumerate(train_dataloader):\n", + " assert isinstance(target_batch, Tensor)\n", + " assert isinstance(source_batch, Tensor)\n", + " target_batch = target_batch.to(device=device, non_blocking=True)\n", + " source_batch = source_batch.to(device=device, non_blocking=True)\n", + " input = torch.cat([target_batch, source_batch], dim=1).div(255)\n", + " transformer.condition_(input)\n", + " warped_batch = transformer(source_batch)\n", + " loss = sim_loss(warped_batch, target_batch)\n", + " optimizer.zero_grad()\n", + " loss.backward()\n", + " optimizer.step()\n", + " total_steps += 1\n", + " pbar.set_postfix(loss=loss.item())\n", + " pbar.update()\n", + " if steps > 0 and step >= steps:\n", + " break\n", + " transform.eval()\n", + " with torch.inference_mode():\n", + " val_loss = torch.tensor(0.0, device=device)\n", + " for target_batch, source_batch in val_dataloader:\n", + " assert isinstance(target_batch, Tensor)\n", + " assert isinstance(source_batch, Tensor)\n", + " target_batch = target_batch.to(device=device, non_blocking=True)\n", + " source_batch = source_batch.to(device=device, non_blocking=True)\n", + " input = torch.cat([target_batch, source_batch], dim=1).div(255)\n", + " transformer.condition_(input)\n", + " warped_batch = transformer(source_batch)\n", + " loss = sim_loss(warped_batch, target_batch)\n", + " val_loss = val_loss.add_(loss)\n", + " val_loss = val_loss.div_(len(val_dataloader))\n", + " pbar.set_postfix(val_loss=val_loss.item())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "transform.eval()\n", + "with torch.inference_mode():\n", + " warped = transformer(source.to(device)).cpu()\n", + "\n", + "fig, axes = plt.subplots(1, 3, figsize=(12, 4), tight_layout=True)\n", + "\n", + "imshow(target, \"target\", ax=axes[0])\n", + "imshow(warped, \"warped\", ax=axes[1])\n", + "imshow(source, \"source\", ax=axes[2])" + ] + }, { "cell_type": "code", "execution_count": null,