https://github.com/turleyjm/cell-division-dl-plugin
Tip revision: f6c3290670f23524d47ccddc873ef5673eda4b3c authored by turleyjm on 11 July 2024, 02:09:47 UTC
Update README.md
Update README.md
Tip revision: f6c3290
trainingUNetOrientation.ipynb
{
"cells": [
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"from cgitb import reset\n",
"import torch\n",
"import albumentations as A\n",
"from albumentations.pytorch import ToTensorV2\n",
"from tqdm import tqdm\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"import torchvision\n",
"from torch.utils.data import DataLoader\n",
"import os\n",
"from PIL import Image\n",
"from torch.utils.data import Dataset\n",
"import numpy as np\n",
"import skimage as sm\n",
"import skimage.io\n",
"from matplotlib import pyplot as plt\n",
"import tifffile\n",
"import timm\n",
"from fastai.vision.all import *\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Hyperparameters and data locations\n",
"\n",
"# --- hardware ---\n",
"DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"\n",
"# --- optimisation ---\n",
"LEARNING_RATE = 1e-4  # reassigned by the run cells before each call to main()\n",
"BATCH_SIZE = 4\n",
"NUM_EPOCHS = 5\n",
"LOAD_MODEL = True  # when True, main() loads an existing checkpoint before training\n",
"\n",
"# --- data loading ---\n",
"NUM_WORKERS = 2\n",
"PIN_MEMORY = True\n",
"IMAGE_HEIGHT = 512  # not referenced by the code visible in this notebook\n",
"IMAGE_WIDTH = 512  # not referenced by the code visible in this notebook\n",
"\n",
"# --- dataset folders ---\n",
"TRAIN_IMG_DIR = \"dat_orientation/train_images/\"\n",
"TRAIN_MASK_DIR = \"dat_orientation/train_masks/\"\n",
"VAL_IMG_DIR = \"dat_orientation/val_images/\"\n",
"VAL_MASK_DIR = \"dat_orientation/val_masks/\""
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# defines the dataset\n",
"\n",
"class OriDataset(Dataset):\n",
"    \"\"\"Pairs each multi-frame image stack with its mask image.\n",
"\n",
"    Every entry of ``image_dir`` except ``.DS_Store`` is one sample; its mask is\n",
"    read from ``mask_dir`` under the same name with ``.tif`` replaced by\n",
"    ``_mask.tif``.\n",
"    \"\"\"\n",
"\n",
"    # keyword names under which frames 0..9 are passed to the transform, in frame order\n",
"    _FRAME_KEYS = (\n",
"        \"image\", \"image0\", \"image1\", \"image2\", \"image3\",\n",
"        \"image4\", \"image5\", \"image6\", \"image7\", \"image8\",\n",
"    )\n",
"\n",
"    def __init__(self, image_dir, mask_dir, transform=None):\n",
"        self.image_dir = image_dir\n",
"        self.mask_dir = mask_dir\n",
"        self.transform = transform\n",
"        # sorted file names of the samples, without the .DS_Store entry\n",
"        self.images = [\n",
"            name for name in sorted(os.listdir(image_dir)) if name != \".DS_Store\"\n",
"        ]\n",
"\n",
"    def __len__(self):\n",
"        return len(self.images)\n",
"\n",
"    def __getitem__(self, index):\n",
"        \"\"\"Return ``(images, mask)`` for the sample at ``index``.\n",
"\n",
"        Without a transform, ``images`` is the frame stack divided by 256 and\n",
"        ``mask`` is the mask array with the value 255 replaced by 1. With a\n",
"        transform, the first 10 frames and the mask are replaced by the\n",
"        transform's outputs.\n",
"        \"\"\"\n",
"        filename = self.images[index]\n",
"        img_path = os.path.join(self.image_dir, filename)\n",
"        mask_path = os.path.join(\n",
"            self.mask_dir, filename.replace(\".tif\", \"_mask.tif\"))\n",
"\n",
"        image = sm.io.imread(img_path).astype(np.float32)\n",
"        raw_mask = np.array(Image.open(mask_path), dtype=np.float32)\n",
"        raw_mask[raw_mask == 255] = 1\n",
"        images = torch.tensor(image / 256).float()\n",
"\n",
"        if self.transform is None:\n",
"            return images, raw_mask\n",
"\n",
"        # normalises and transforms all 10 frames and the mask in one call\n",
"        frames = {key: image[k] for k, key in enumerate(self._FRAME_KEYS)}\n",
"        transformed = self.transform(**frames, mask=raw_mask)\n",
"        for k, key in enumerate(self._FRAME_KEYS):\n",
"            images[k] = transformed[key]\n",
"\n",
"        # to check that the transforms behave as intended, save the frames and\n",
"        # mask before and after the transform:\n",
"        # save_transform(image, raw_mask, transformed)\n",
"\n",
"        return images, transformed[\"mask\"]\n",
"\n",
"# saves the before and after transform by the augmentations\n",
"def save_transform(image, mask0, transformed):\n",
"    \"\"\"Write a before/after montage of one sample to transformResults/transform.tif.\n",
"\n",
"    The montage has 10 planes of 1034 x 1034 pixels (two 512-pixel tiles with a\n",
"    10-pixel gap): top-left the original frames, top-right the transformed\n",
"    frames, bottom-left the original mask and bottom-right the transformed\n",
"    mask (the masks are repeated on every plane).\n",
"\n",
"    Args:\n",
"        image: original frame stack, assigned to a (10, 512, 512) region.\n",
"        mask0: mask before the transform; multiplied by 255, so presumably 0/1.\n",
"        transformed: output of the transform, read under the keys \"image\",\n",
"            \"image0\" ... \"image8\" and \"mask\".\n",
"    \"\"\"\n",
"    frame_keys = (\n",
"        \"image\", \"image0\", \"image1\", \"image2\", \"image3\",\n",
"        \"image4\", \"image5\", \"image6\", \"image7\", \"image8\",\n",
"    )\n",
"\n",
"    result = np.zeros([10, 1034, 1034])\n",
"    result[:, 0:512, 0:512] = image\n",
"    for k, key in enumerate(frame_keys):\n",
"        result[k, 0:512, 522:] = np.array(transformed[key])*255\n",
"\n",
"    result[:, 522:, 0:512] = mask0*255\n",
"    result[:, 522:, 522:] = np.array(transformed[\"mask\"])*255\n",
"\n",
"    result = np.asarray(result, \"uint8\")\n",
"    # create the output folder so the write does not fail on a fresh checkout\n",
"    os.makedirs(\"transformResults\", exist_ok=True)\n",
"    tifffile.imwrite(f\"transformResults/transform.tif\", result)\n",
"\n",
"\n",
"# util\n",
"\n",
"# save model parameters\n",
"def save_checkpoint(state, filename=\"models/UNetOrientation_new.pth.tar\"):\n",
"    \"\"\"Serialise ``state`` to ``filename`` with ``torch.save``.\n",
"\n",
"    When ``filename`` is a path, its parent folder is created first if needed,\n",
"    so a missing ``models/`` directory does not abort the run after an epoch\n",
"    of training has already been spent.\n",
"    \"\"\"\n",
"    print(\"=> Saving checkpoint\")\n",
"    if isinstance(filename, (str, os.PathLike)):\n",
"        parent = os.path.dirname(filename)\n",
"        if parent:\n",
"            os.makedirs(parent, exist_ok=True)\n",
"    torch.save(state, filename)\n",
"\n",
"# load model parameters\n",
"def load_checkpoint(checkpoint, model):\n",
"    \"\"\"Load the weights stored under ``checkpoint[\"state_dict\"]`` into ``model``.\"\"\"\n",
"    print(\"=> Loading checkpoint\")\n",
"    weights = checkpoint[\"state_dict\"]\n",
"    model.load_state_dict(weights)\n",
"\n",
"# Make the dataloader \n",
"def get_loaders(\n",
"    train_dir,\n",
"    train_maskdir,\n",
"    val_dir,\n",
"    val_maskdir,\n",
"    batch_size,\n",
"    train_transform,\n",
"    val_transform,\n",
"    num_workers=4,\n",
"    pin_memory=True\n",
"):\n",
"    \"\"\"Build the training and validation ``DataLoader`` objects.\n",
"\n",
"    Both loaders wrap an ``OriDataset`` and share ``batch_size``,\n",
"    ``num_workers`` and ``pin_memory``; only the training loader shuffles.\n",
"\n",
"    Returns:\n",
"        The tuple ``(train_loader, val_loader)``.\n",
"    \"\"\"\n",
"    # options common to both loaders\n",
"    shared = dict(\n",
"        batch_size=batch_size,\n",
"        num_workers=num_workers,\n",
"        pin_memory=pin_memory,\n",
"    )\n",
"\n",
"    train_ds = OriDataset(\n",
"        image_dir=train_dir, mask_dir=train_maskdir, transform=train_transform)\n",
"    train_loader = DataLoader(train_ds, shuffle=True, **shared)\n",
"\n",
"    val_ds = OriDataset(\n",
"        image_dir=val_dir, mask_dir=val_maskdir, transform=val_transform)\n",
"    val_loader = DataLoader(val_ds, shuffle=False, **shared)\n",
"\n",
"    return train_loader, val_loader\n",
"\n",
"# define metric to assess model performance \n",
"def check_accuracy(loader, model, device=\"cuda\"):\n",
"    \"\"\"Print the pixel accuracy and the Dice score of ``model`` over ``loader``.\n",
"\n",
"    Predictions are the sigmoid of the model output thresholded at 0.5. The\n",
"    Dice score is computed per batch and averaged over the number of batches.\n",
"    The model is set to eval mode for the pass and to train mode afterwards.\n",
"    \"\"\"\n",
"    correct_total = 0\n",
"    pixel_total = 0\n",
"    dice_total = 0\n",
"    model.eval()\n",
"    progress = tqdm(loader)\n",
"\n",
"    with torch.no_grad():\n",
"        for x, y in progress:\n",
"            x = x.to(device)\n",
"            y = y.to(device).unsqueeze(1)  # add the channel axis to the targets\n",
"            binary = (torch.sigmoid(model(x)) > 0.5).float()\n",
"            correct_total += (binary == y).sum()\n",
"            pixel_total += torch.numel(binary)\n",
"            overlap = (binary * y).sum()\n",
"            # the 1e-8 keeps the ratio defined when both prediction and target are empty\n",
"            dice_total += (2 * overlap) / ((binary + y).sum() + 1e-8)\n",
"\n",
"    print(f\"Accuracy {correct_total/pixel_total*100}%\")\n",
"    print(f\"Dice score {dice_total/len(loader)}\")\n",
"    model.train()\n",
"\n",
"# saves the ground truth with the model prediction to folder saved_images\n",
"def save_predictions_as_imgs(loader, model, folder=\"saved_images/\", device=\"cuda\"):\n",
"    \"\"\"Save thresholded predictions and targets of the first batch as PNG files.\n",
"\n",
"    For sample ``i`` of the first batch of ``loader`` the prediction is written\n",
"    to ``{folder}pred_{i}.png`` and the target to ``{folder}img_{i}.png``.\n",
"    ``folder`` is used as a plain prefix, so it should end with a path\n",
"    separator. The directory part of ``folder`` is created if it is missing.\n",
"    The model is set to eval mode for the pass and to train mode afterwards.\n",
"    \"\"\"\n",
"    # make sure the output directory exists before any image is written\n",
"    target_dir = os.path.dirname(folder)\n",
"    if target_dir:\n",
"        os.makedirs(target_dir, exist_ok=True)\n",
"\n",
"    model.eval()\n",
"    with torch.no_grad():\n",
"        for x, y in loader:\n",
"            x = x.to(device)\n",
"            preds = (torch.sigmoid(model(x)) > 0.5).float()\n",
"            truth = y.unsqueeze(1)\n",
"            for i in range(preds.shape[0]):\n",
"                torchvision.utils.save_image(\n",
"                    preds[i], f\"{folder}pred_{i}.png\"\n",
"                )\n",
"                torchvision.utils.save_image(\n",
"                    truth[i], f\"{folder}img_{i}.png\")\n",
"\n",
"            # only the first batch is saved; later batches would reuse the same file names\n",
"            break\n",
"\n",
"    model.train()\n",
"\n",
"# train the model and update parameters of model\n",
"def train_fn(loader, model, optimizer, loss_fn, scaler):\n",
"    \"\"\"Run one pass over ``loader``, updating ``model`` after every batch.\n",
"\n",
"    The forward pass runs under ``torch.cuda.amp.autocast`` and the backward\n",
"    pass and optimizer step go through ``scaler``. The loss of the current\n",
"    batch is shown on the progress bar.\n",
"    \"\"\"\n",
"    progress = tqdm(loader)\n",
"\n",
"    for data, targets in progress:\n",
"        data = data.to(device=DEVICE)\n",
"        # add the channel axis to the targets before moving them to the device\n",
"        targets = targets.unsqueeze(1).to(device=DEVICE)\n",
"\n",
"        # forward\n",
"        with torch.cuda.amp.autocast():\n",
"            loss = loss_fn(model(data), targets)\n",
"\n",
"        # backward\n",
"        optimizer.zero_grad()\n",
"        scaler.scale(loss).backward()\n",
"        scaler.step(optimizer)\n",
"        scaler.update()\n",
"\n",
"        # update tqdm loop\n",
"        progress.set_postfix(loss=loss.item())\n",
"\n",
"# load and train the deep learning model\n",
"def main():\n",
"    \"\"\"Build the model and the data loaders, then train for NUM_EPOCHS epochs.\n",
"\n",
"    Reads the notebook-level constants (LEARNING_RATE, DEVICE, BATCH_SIZE,\n",
"    NUM_EPOCHS, NUM_WORKERS, PIN_MEMORY, LOAD_MODEL and the four data folders)\n",
"    at call time. After every epoch the model and optimizer state are\n",
"    checkpointed, the validation accuracy and Dice score are printed and the\n",
"    predictions for the first validation batch are saved to saved_images/.\n",
"    \"\"\"\n",
"    # every additional frame key is declared to albumentations as an image\n",
"    # target, next to the mask\n",
"    extra_targets = {f\"image{i}\": \"image\" for i in range(10)}\n",
"    extra_targets[\"mask\"] = \"mask\"\n",
"\n",
"    # augmentations for training model\n",
"    train_transform = A.Compose(\n",
"        [\n",
"            A.Rotate(limit=35, p=1.0),\n",
"            A.HorizontalFlip(p=0.5),\n",
"            A.VerticalFlip(p=0.5),\n",
"            A.GaussianBlur(blur_limit=(3, 5), p=0.3),\n",
"            A.Normalize(\n",
"                mean=0,\n",
"                std=1,\n",
"                max_pixel_value=255.0,\n",
"            ),\n",
"            A.RandomBrightnessContrast(p=0.3),\n",
"            A.Sharpen(alpha=(0.2, 0.5), lightness=(0.5, 1.0), p=0.3),\n",
"            ToTensorV2(),\n",
"        ],\n",
"        additional_targets=extra_targets,\n",
"    )\n",
"    # augmentations for validation data\n",
"    val_transform = A.Compose(\n",
"        [\n",
"            A.Normalize(\n",
"                mean=0,\n",
"                std=1,\n",
"                max_pixel_value=255.0,\n",
"            ),\n",
"            ToTensorV2(),\n",
"        ],\n",
"        additional_targets=extra_targets,\n",
"    )\n",
"\n",
"    # make the UNetOrientation model\n",
"    resnet = timm.create_model(\"resnet34\", pretrained=True)\n",
"    # change model first layer to have 10 input features\n",
"    resnet.conv1 = nn.Conv2d(\n",
"        10, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
"\n",
"    # use the network without its last two child modules as the U-Net encoder\n",
"    encoder = nn.Sequential(*list(resnet.children())[:-2])\n",
"    model = DynamicUnet(encoder, 1, (120, 120), norm_type=None).to(DEVICE)\n",
"\n",
"    loss_fn = nn.BCEWithLogitsLoss()  # if out_channels > 1 => cross entropy loss\n",
"    optimizer = optim.Adam(\n",
"        model.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.999), eps=1e-08)  # adam learner\n",
"\n",
"    # Make the dataloader\n",
"    train_loader, val_loader = get_loaders(\n",
"        TRAIN_IMG_DIR,\n",
"        TRAIN_MASK_DIR,\n",
"        VAL_IMG_DIR,\n",
"        VAL_MASK_DIR,\n",
"        BATCH_SIZE,\n",
"        train_transform,\n",
"        val_transform,\n",
"        NUM_WORKERS,\n",
"        PIN_MEMORY,\n",
"    )\n",
"\n",
"    # Load training model if one is available\n",
"    if LOAD_MODEL:\n",
"        # map_location lets a checkpoint saved on a GPU load when DEVICE is \"cpu\"\n",
"        # NOTE(review): checkpoints are written to save_checkpoint's default file\n",
"        # (models/UNetOrientation_new.pth.tar) but read from\n",
"        # models/UNetOrientation.pth.tar, so a later call to main() does not\n",
"        # resume from what this call saved unless the file is renamed - confirm\n",
"        # that this is intended.\n",
"        checkpoint = torch.load(\n",
"            \"models/UNetOrientation.pth.tar\", map_location=DEVICE)\n",
"        load_checkpoint(checkpoint, model)\n",
"\n",
"    scaler = torch.cuda.amp.GradScaler()\n",
"\n",
"    for epoch in range(NUM_EPOCHS):\n",
"        # train model\n",
"        train_fn(train_loader, model, optimizer, loss_fn, scaler)\n",
"\n",
"        # save model\n",
"        checkpoint = {\n",
"            \"state_dict\": model.state_dict(),\n",
"            \"optimizer\": optimizer.state_dict(),\n",
"        }\n",
"        save_checkpoint(checkpoint)\n",
"\n",
"        # check accuracy\n",
"        check_accuracy(val_loader, model, device=DEVICE)\n",
"        save_predictions_as_imgs(\n",
"            val_loader, model, folder=\"saved_images/\", device=DEVICE)\n"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r\n",
" 0%| | 0/42 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"=> Loading checkpoint\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 42/42 [00:22<00:00, 1.85it/s, loss=0.242]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"=> Saving checkpoint\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 11/11 [00:02<00:00, 5.09it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy 89.76678466796875%\n",
"Dice score 0.6154643297195435\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"100%|██████████| 42/42 [00:22<00:00, 1.84it/s, loss=0.234]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"=> Saving checkpoint\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 11/11 [00:02<00:00, 5.28it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy 90.25120544433594%\n",
"Dice score 0.6295679211616516\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"100%|██████████| 42/42 [00:40<00:00, 1.05it/s, loss=0.189]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"=> Saving checkpoint\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 11/11 [00:24<00:00, 2.27s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy 91.2807846069336%\n",
"Dice score 0.6977839469909668\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"100%|██████████| 42/42 [00:52<00:00, 1.26s/it, loss=0.205]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"=> Saving checkpoint\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 11/11 [00:02<00:00, 5.43it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy 93.46248626708984%\n",
"Dice score 0.7872283458709717\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"100%|██████████| 42/42 [00:22<00:00, 1.85it/s, loss=0.135]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"=> Saving checkpoint\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 11/11 [00:02<00:00, 5.14it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy 94.22380828857422%\n",
"Dice score 0.8135465383529663\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# Train model \n",
"LEARNING_RATE = 1e-4\n",
"main()"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r\n",
" 0%| | 0/42 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"=> Loading checkpoint\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 42/42 [00:22<00:00, 1.89it/s, loss=0.103] \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"=> Saving checkpoint\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 11/11 [00:02<00:00, 5.08it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy 95.23346710205078%\n",
"Dice score 0.8492602109909058\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"100%|██████████| 42/42 [00:22<00:00, 1.86it/s, loss=0.104] \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"=> Saving checkpoint\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 11/11 [00:02<00:00, 5.14it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy 95.6044921875%\n",
"Dice score 0.8634843230247498\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"100%|██████████| 42/42 [00:22<00:00, 1.84it/s, loss=0.0797]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"=> Saving checkpoint\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 11/11 [00:02<00:00, 5.27it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy 95.51014709472656%\n",
"Dice score 0.8615202903747559\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"100%|██████████| 42/42 [00:22<00:00, 1.84it/s, loss=0.112] \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"=> Saving checkpoint\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 11/11 [00:15<00:00, 1.45s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy 95.40090942382812%\n",
"Dice score 0.8561145663261414\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"100%|██████████| 42/42 [00:58<00:00, 1.40s/it, loss=0.0815]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"=> Saving checkpoint\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 11/11 [00:02<00:00, 5.07it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy 95.71009826660156%\n",
"Dice score 0.8668902516365051\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"LEARNING_RATE = 5e-5\n",
"main()"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r\n",
" 0%| | 0/42 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"=> Loading checkpoint\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 42/42 [00:22<00:00, 1.86it/s, loss=0.106] \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"=> Saving checkpoint\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 11/11 [00:02<00:00, 5.21it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy 95.67807006835938%\n",
"Dice score 0.865170955657959\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"100%|██████████| 42/42 [00:22<00:00, 1.86it/s, loss=0.072] \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"=> Saving checkpoint\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 11/11 [00:02<00:00, 5.14it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy 95.71603393554688%\n",
"Dice score 0.8670840263366699\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"100%|██████████| 42/42 [00:22<00:00, 1.83it/s, loss=0.0973]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"=> Saving checkpoint\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 11/11 [00:02<00:00, 4.96it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy 95.7632827758789%\n",
"Dice score 0.8682851195335388\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"100%|██████████| 42/42 [00:22<00:00, 1.84it/s, loss=0.103] \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"=> Saving checkpoint\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 11/11 [00:02<00:00, 5.36it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy 95.76087951660156%\n",
"Dice score 0.8681668639183044\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"100%|██████████| 42/42 [00:22<00:00, 1.86it/s, loss=0.0923]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"=> Saving checkpoint\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 11/11 [00:02<00:00, 5.26it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy 95.76690673828125%\n",
"Dice score 0.8683835864067078\n"
]
}
],
"source": [
"LEARNING_RATE = 1e-5\n",
"main()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------------------------------------\n",
" Layer (type) Output Shape Param #\n",
"================================================================\n",
" Conv2d-1 [-1, 64, 60, 60] 31,360\n",
" BatchNorm2d-2 [-1, 64, 60, 60] 128\n",
" ReLU-3 [-1, 64, 60, 60] 0\n",
" MaxPool2d-4 [-1, 64, 30, 30] 0\n",
" Conv2d-5 [-1, 64, 30, 30] 36,864\n",
" BatchNorm2d-6 [-1, 64, 30, 30] 128\n",
" Identity-7 [-1, 64, 30, 30] 0\n",
" ReLU-8 [-1, 64, 30, 30] 0\n",
" Identity-9 [-1, 64, 30, 30] 0\n",
" Conv2d-10 [-1, 64, 30, 30] 36,864\n",
" BatchNorm2d-11 [-1, 64, 30, 30] 128\n",
" ReLU-12 [-1, 64, 30, 30] 0\n",
" BasicBlock-13 [-1, 64, 30, 30] 0\n",
" Conv2d-14 [-1, 64, 30, 30] 36,864\n",
" BatchNorm2d-15 [-1, 64, 30, 30] 128\n",
" Identity-16 [-1, 64, 30, 30] 0\n",
" ReLU-17 [-1, 64, 30, 30] 0\n",
" Identity-18 [-1, 64, 30, 30] 0\n",
" Conv2d-19 [-1, 64, 30, 30] 36,864\n",
" BatchNorm2d-20 [-1, 64, 30, 30] 128\n",
" ReLU-21 [-1, 64, 30, 30] 0\n",
" BasicBlock-22 [-1, 64, 30, 30] 0\n",
" Conv2d-23 [-1, 64, 30, 30] 36,864\n",
" BatchNorm2d-24 [-1, 64, 30, 30] 128\n",
" Identity-25 [-1, 64, 30, 30] 0\n",
" ReLU-26 [-1, 64, 30, 30] 0\n",
" Identity-27 [-1, 64, 30, 30] 0\n",
" Conv2d-28 [-1, 64, 30, 30] 36,864\n",
" BatchNorm2d-29 [-1, 64, 30, 30] 128\n",
" ReLU-30 [-1, 64, 30, 30] 0\n",
" BasicBlock-31 [-1, 64, 30, 30] 0\n",
" Conv2d-32 [-1, 128, 15, 15] 73,728\n",
" BatchNorm2d-33 [-1, 128, 15, 15] 256\n",
" Identity-34 [-1, 128, 15, 15] 0\n",
" ReLU-35 [-1, 128, 15, 15] 0\n",
" Identity-36 [-1, 128, 15, 15] 0\n",
" Conv2d-37 [-1, 128, 15, 15] 147,456\n",
" BatchNorm2d-38 [-1, 128, 15, 15] 256\n",
" Conv2d-39 [-1, 128, 15, 15] 8,192\n",
" BatchNorm2d-40 [-1, 128, 15, 15] 256\n",
" ReLU-41 [-1, 128, 15, 15] 0\n",
" BasicBlock-42 [-1, 128, 15, 15] 0\n",
" Conv2d-43 [-1, 128, 15, 15] 147,456\n",
" BatchNorm2d-44 [-1, 128, 15, 15] 256\n",
" Identity-45 [-1, 128, 15, 15] 0\n",
" ReLU-46 [-1, 128, 15, 15] 0\n",
" Identity-47 [-1, 128, 15, 15] 0\n",
" Conv2d-48 [-1, 128, 15, 15] 147,456\n",
" BatchNorm2d-49 [-1, 128, 15, 15] 256\n",
" ReLU-50 [-1, 128, 15, 15] 0\n",
" BasicBlock-51 [-1, 128, 15, 15] 0\n",
" Conv2d-52 [-1, 128, 15, 15] 147,456\n",
" BatchNorm2d-53 [-1, 128, 15, 15] 256\n",
" Identity-54 [-1, 128, 15, 15] 0\n",
" ReLU-55 [-1, 128, 15, 15] 0\n",
" Identity-56 [-1, 128, 15, 15] 0\n",
" Conv2d-57 [-1, 128, 15, 15] 147,456\n",
" BatchNorm2d-58 [-1, 128, 15, 15] 256\n",
" ReLU-59 [-1, 128, 15, 15] 0\n",
" BasicBlock-60 [-1, 128, 15, 15] 0\n",
" Conv2d-61 [-1, 128, 15, 15] 147,456\n",
" BatchNorm2d-62 [-1, 128, 15, 15] 256\n",
" Identity-63 [-1, 128, 15, 15] 0\n",
" ReLU-64 [-1, 128, 15, 15] 0\n",
" Identity-65 [-1, 128, 15, 15] 0\n",
" Conv2d-66 [-1, 128, 15, 15] 147,456\n",
" BatchNorm2d-67 [-1, 128, 15, 15] 256\n",
" ReLU-68 [-1, 128, 15, 15] 0\n",
" BasicBlock-69 [-1, 128, 15, 15] 0\n",
" Conv2d-70 [-1, 256, 8, 8] 294,912\n",
" BatchNorm2d-71 [-1, 256, 8, 8] 512\n",
" Identity-72 [-1, 256, 8, 8] 0\n",
" ReLU-73 [-1, 256, 8, 8] 0\n",
" Identity-74 [-1, 256, 8, 8] 0\n",
" Conv2d-75 [-1, 256, 8, 8] 589,824\n",
" BatchNorm2d-76 [-1, 256, 8, 8] 512\n",
" Conv2d-77 [-1, 256, 8, 8] 32,768\n",
" BatchNorm2d-78 [-1, 256, 8, 8] 512\n",
" ReLU-79 [-1, 256, 8, 8] 0\n",
" BasicBlock-80 [-1, 256, 8, 8] 0\n",
" Conv2d-81 [-1, 256, 8, 8] 589,824\n",
" BatchNorm2d-82 [-1, 256, 8, 8] 512\n",
" Identity-83 [-1, 256, 8, 8] 0\n",
" ReLU-84 [-1, 256, 8, 8] 0\n",
" Identity-85 [-1, 256, 8, 8] 0\n",
" Conv2d-86 [-1, 256, 8, 8] 589,824\n",
" BatchNorm2d-87 [-1, 256, 8, 8] 512\n",
" ReLU-88 [-1, 256, 8, 8] 0\n",
" BasicBlock-89 [-1, 256, 8, 8] 0\n",
" Conv2d-90 [-1, 256, 8, 8] 589,824\n",
" BatchNorm2d-91 [-1, 256, 8, 8] 512\n",
" Identity-92 [-1, 256, 8, 8] 0\n",
" ReLU-93 [-1, 256, 8, 8] 0\n",
" Identity-94 [-1, 256, 8, 8] 0\n",
" Conv2d-95 [-1, 256, 8, 8] 589,824\n",
" BatchNorm2d-96 [-1, 256, 8, 8] 512\n",
" ReLU-97 [-1, 256, 8, 8] 0\n",
" BasicBlock-98 [-1, 256, 8, 8] 0\n",
" Conv2d-99 [-1, 256, 8, 8] 589,824\n",
" BatchNorm2d-100 [-1, 256, 8, 8] 512\n",
" Identity-101 [-1, 256, 8, 8] 0\n",
" ReLU-102 [-1, 256, 8, 8] 0\n",
" Identity-103 [-1, 256, 8, 8] 0\n",
" Conv2d-104 [-1, 256, 8, 8] 589,824\n",
" BatchNorm2d-105 [-1, 256, 8, 8] 512\n",
" ReLU-106 [-1, 256, 8, 8] 0\n",
" BasicBlock-107 [-1, 256, 8, 8] 0\n",
" Conv2d-108 [-1, 256, 8, 8] 589,824\n",
" BatchNorm2d-109 [-1, 256, 8, 8] 512\n",
" Identity-110 [-1, 256, 8, 8] 0\n",
" ReLU-111 [-1, 256, 8, 8] 0\n",
" Identity-112 [-1, 256, 8, 8] 0\n",
" Conv2d-113 [-1, 256, 8, 8] 589,824\n",
" BatchNorm2d-114 [-1, 256, 8, 8] 512\n",
" ReLU-115 [-1, 256, 8, 8] 0\n",
" BasicBlock-116 [-1, 256, 8, 8] 0\n",
" Conv2d-117 [-1, 256, 8, 8] 589,824\n",
" BatchNorm2d-118 [-1, 256, 8, 8] 512\n",
" Identity-119 [-1, 256, 8, 8] 0\n",
" ReLU-120 [-1, 256, 8, 8] 0\n",
" Identity-121 [-1, 256, 8, 8] 0\n",
" Conv2d-122 [-1, 256, 8, 8] 589,824\n",
" BatchNorm2d-123 [-1, 256, 8, 8] 512\n",
" ReLU-124 [-1, 256, 8, 8] 0\n",
" BasicBlock-125 [-1, 256, 8, 8] 0\n",
" Conv2d-126 [-1, 512, 4, 4] 1,179,648\n",
" BatchNorm2d-127 [-1, 512, 4, 4] 1,024\n",
" Identity-128 [-1, 512, 4, 4] 0\n",
" ReLU-129 [-1, 512, 4, 4] 0\n",
" Identity-130 [-1, 512, 4, 4] 0\n",
" Conv2d-131 [-1, 512, 4, 4] 2,359,296\n",
" BatchNorm2d-132 [-1, 512, 4, 4] 1,024\n",
" Conv2d-133 [-1, 512, 4, 4] 131,072\n",
" BatchNorm2d-134 [-1, 512, 4, 4] 1,024\n",
" ReLU-135 [-1, 512, 4, 4] 0\n",
" BasicBlock-136 [-1, 512, 4, 4] 0\n",
" Conv2d-137 [-1, 512, 4, 4] 2,359,296\n",
" BatchNorm2d-138 [-1, 512, 4, 4] 1,024\n",
" Identity-139 [-1, 512, 4, 4] 0\n",
" ReLU-140 [-1, 512, 4, 4] 0\n",
" Identity-141 [-1, 512, 4, 4] 0\n",
" Conv2d-142 [-1, 512, 4, 4] 2,359,296\n",
" BatchNorm2d-143 [-1, 512, 4, 4] 1,024\n",
" ReLU-144 [-1, 512, 4, 4] 0\n",
" BasicBlock-145 [-1, 512, 4, 4] 0\n",
" Conv2d-146 [-1, 512, 4, 4] 2,359,296\n",
" BatchNorm2d-147 [-1, 512, 4, 4] 1,024\n",
" Identity-148 [-1, 512, 4, 4] 0\n",
" ReLU-149 [-1, 512, 4, 4] 0\n",
" Identity-150 [-1, 512, 4, 4] 0\n",
" Conv2d-151 [-1, 512, 4, 4] 2,359,296\n",
" BatchNorm2d-152 [-1, 512, 4, 4] 1,024\n",
" ReLU-153 [-1, 512, 4, 4] 0\n",
" BasicBlock-154 [-1, 512, 4, 4] 0\n",
" BatchNorm2d-155 [-1, 512, 4, 4] 1,024\n",
" ReLU-156 [-1, 512, 4, 4] 0\n",
" Conv2d-157 [-1, 1024, 4, 4] 4,719,616\n",
" ReLU-158 [-1, 1024, 4, 4] 0\n",
" Conv2d-159 [-1, 512, 4, 4] 4,719,104\n",
" ReLU-160 [-1, 512, 4, 4] 0\n",
" Conv2d-161 [-1, 1024, 4, 4] 525,312\n",
" ReLU-162 [-1, 1024, 4, 4] 0\n",
" PixelShuffle-163 [-1, 256, 8, 8] 0\n",
" BatchNorm2d-164 [-1, 256, 8, 8] 512\n",
" ReLU-165 [-1, 512, 8, 8] 0\n",
" Conv2d-166 [-1, 512, 8, 8] 2,359,808\n",
" ReLU-167 [-1, 512, 8, 8] 0\n",
" Conv2d-168 [-1, 512, 8, 8] 2,359,808\n",
" ReLU-169 [-1, 512, 8, 8] 0\n",
" UnetBlock-170 [-1, 512, 8, 8] 0\n",
" Conv2d-171 [-1, 1024, 8, 8] 525,312\n",
" ReLU-172 [-1, 1024, 8, 8] 0\n",
" PixelShuffle-173 [-1, 256, 16, 16] 0\n",
" BatchNorm2d-174 [-1, 128, 15, 15] 256\n",
" ReLU-175 [-1, 384, 15, 15] 0\n",
" Conv2d-176 [-1, 384, 15, 15] 1,327,488\n",
" ReLU-177 [-1, 384, 15, 15] 0\n",
" Conv2d-178 [-1, 384, 15, 15] 1,327,488\n",
" ReLU-179 [-1, 384, 15, 15] 0\n",
" UnetBlock-180 [-1, 384, 15, 15] 0\n",
" Conv2d-181 [-1, 768, 15, 15] 295,680\n",
" ReLU-182 [-1, 768, 15, 15] 0\n",
" PixelShuffle-183 [-1, 192, 30, 30] 0\n",
" BatchNorm2d-184 [-1, 64, 30, 30] 128\n",
" ReLU-185 [-1, 256, 30, 30] 0\n",
" Conv2d-186 [-1, 256, 30, 30] 590,080\n",
" ReLU-187 [-1, 256, 30, 30] 0\n",
" Conv2d-188 [-1, 256, 30, 30] 590,080\n",
" ReLU-189 [-1, 256, 30, 30] 0\n",
" UnetBlock-190 [-1, 256, 30, 30] 0\n",
" Conv2d-191 [-1, 512, 30, 30] 131,584\n",
" ReLU-192 [-1, 512, 30, 30] 0\n",
" PixelShuffle-193 [-1, 128, 60, 60] 0\n",
" BatchNorm2d-194 [-1, 64, 60, 60] 128\n",
" ReLU-195 [-1, 192, 60, 60] 0\n",
" Conv2d-196 [-1, 96, 60, 60] 165,984\n",
" ReLU-197 [-1, 96, 60, 60] 0\n",
" Conv2d-198 [-1, 96, 60, 60] 83,040\n",
" ReLU-199 [-1, 96, 60, 60] 0\n",
" UnetBlock-200 [-1, 96, 60, 60] 0\n",
" Conv2d-201 [-1, 384, 60, 60] 37,248\n",
" ReLU-202 [-1, 384, 60, 60] 0\n",
" PixelShuffle-203 [-1, 96, 120, 120] 0\n",
" ResizeToOrig-204 [-1, 96, 120, 120] 0\n",
" MergeLayer-205 [-1, 106, 120, 120] 0\n",
" Conv2d-206 [-1, 106, 120, 120] 101,230\n",
" ReLU-207 [-1, 106, 120, 120] 0\n",
" Conv2d-208 [-1, 106, 120, 120] 101,230\n",
" ReLU-209 [-1, 106, 120, 120] 0\n",
" ResBlock-210 [-1, 106, 120, 120] 0\n",
" Conv2d-211 [-1, 1, 120, 120] 107\n",
" ToTensorBase-212 [-1, 1, 120, 120] 0\n",
"================================================================\n",
"Total params: 41,268,871\n",
"Trainable params: 41,268,871\n",
"Non-trainable params: 0\n",
"----------------------------------------------------------------\n",
"Input size (MB): 0.55\n",
"Forward/backward pass size (MB): 200.90\n",
"Params size (MB): 157.43\n",
"Estimated Total Size (MB): 358.88\n",
"----------------------------------------------------------------\n"
]
}
],
"source": [
"# displays layers and parameters of models\n",
"from torchsummary import summary\n",
"\n",
"# same construction as in main(): ResNet-34 with a 10-channel first convolution\n",
"resnet = timm.create_model(\"resnet34\", pretrained=True)\n",
"resnet.conv1 = nn.Conv2d(\n",
"    10, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False\n",
")\n",
"\n",
"# everything except the last two child modules becomes the U-Net encoder\n",
"m = nn.Sequential(*list(resnet.children())[:-2])\n",
"model = DynamicUnet(m, 1, (120, 120), norm_type=None).to(DEVICE)\n",
"\n",
"summary(model, (10, 120, 120))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.6"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
