https://github.com/turleyjm/cell-division-dl-plugin
Tip revision: f6c3290670f23524d47ccddc873ef5673eda4b3c authored by turleyjm on 11 July 2024, 02:09:47 UTC
Update README.md
Update README.md
Tip revision: f6c3290
testingUNetCellDivision3.ipynb
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from fastai.vision.all import *\n",
"import cv2\n",
"import skimage as sm\n",
"import skimage.io\n",
"import tifffile\n",
"path = Path('/notebooks/deepLearningPaper/UNetCellDivision3Eval')\n",
"\n",
"# Data layout: images in dat3/train_images, masks in dat3/train_masks.\n",
"valid_fnames = (path/'valid.txt').read_text().split('\\n')\n",
"path_im = path/'dat3/train_images'\n",
"path_lbl = path/'dat3/train_masks'\n",
"codes = np.loadtxt(path/'codes.txt', dtype=str)\n",
"fnames = get_image_files(path_im)\n",
"lbl_names = get_image_files(path_lbl)\n",
"\n",
"\n",
"def get_msk(o):\n",
"    \"\"\"Return the mask file for image path `o`: path_lbl/<stem>_mask<suffix>.\"\"\"\n",
"    return path_lbl/f'{o.stem}_mask{o.suffix}'\n",
"\n",
"\n",
"camvid = DataBlock(blocks=(ImageBlock, MaskBlock(codes)),\n",
"                   get_items=get_image_files,\n",
"                   splitter=FileSplitter(path/'valid.txt'),\n",
"                   get_y=get_msk,\n",
"                   batch_tfms=[*aug_transforms(), Normalize.from_stats(*imagenet_stats)])\n",
"\n",
"dls = camvid.dataloaders(path_im, bs=8)\n",
"dls.vocab = codes\n",
"name2id = {v: k for k, v in enumerate(codes)}\n",
"void_code = name2id['Void']\n",
"\n",
"\n",
"def acc_camvid(inp, targ):\n",
"    \"\"\"Accuracy of inp.argmax(dim=1) against targ, ignoring pixels whose target is void_code.\"\"\"\n",
"    targ = targ.squeeze(1)\n",
"    mask = targ != void_code\n",
"    return (inp.argmax(dim=1)[mask] == targ[mask]).float().mean()\n",
"\n",
"\n",
"opt = ranger\n",
"learn = unet_learner(dls, resnet34, metrics=acc_camvid, self_attention=True,\n",
"                     act_cls=Mish, opt_func=opt).to_fp16()\n",
"\n",
"device = \"cuda\"\n",
"# Load trained weights; fastai's Learner.load resolves the name under learn.path/learn.model_dir.\n",
"learn.load('UNetCellDivision3')\n",
"model = learn.model\n",
"model = model.to(device)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"from cgitb import reset\n",
"import torch\n",
"from tqdm import tqdm\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"import torchvision\n",
"from torch.utils.data import DataLoader\n",
"import os\n",
"from PIL import Image\n",
"from torch.utils.data import Dataset\n",
"import numpy as np\n",
"import skimage as sm\n",
"import skimage.io\n",
"from matplotlib import pyplot as plt\n",
"import tifffile\n",
"from fastai.vision.all import *\n",
"from skimage.feature import blob, blob_dog, blob_log, blob_doh\n",
"from os.path import exists\n",
"\n",
"# Run-time settings (not referenced by the helpers defined in this cell).\n",
"if torch.cuda.is_available():\n",
"    DEVICE = \"cuda\"\n",
"else:\n",
"    DEVICE = \"cpu\"\n",
"BATCH_SIZE, NUM_WORKERS = 8, 2\n",
"IMAGE_HEIGHT = IMAGE_WIDTH = 512\n",
"PIN_MEMORY = True\n",
"\n",
"# util\n",
"\n",
"def createFolder(directory):\n",
"    \"\"\"Create `directory` (with any missing parents) if it does not exist yet.\n",
"\n",
"    Errors are printed rather than raised, so callers can treat this as best-effort.\n",
"    \"\"\"\n",
"    try:\n",
"        # exist_ok=True replaces the racy exists()-then-makedirs() check.\n",
"        os.makedirs(directory, exist_ok=True)\n",
"    except OSError:\n",
"        print(\"Error: Creating directory. \" + directory)\n",
"\n",
" \n",
"def sortDL(df, t, x, y):\n",
"    \"\"\"Return the rows of df near detection (t, x, y) in the surrounding frames.\n",
"\n",
"    Keeps rows whose T is t-2, t-1, t+1 or t+2 (concatenated in that order) and\n",
"    whose X, Y lie in [x-13, x+13) x [y-13, y+13), each bound first clipped to\n",
"    the range [0, 511].\n",
"    \"\"\"\n",
"    nearby = pd.concat([df[df[\"T\"] == t + dt] for dt in (-2, -1, 1, 2)])\n",
"\n",
"    lo_x, hi_x = max(x - 13, 0), min(x + 13, 511)\n",
"    lo_y, hi_y = max(y - 13, 0), min(y + 13, 511)\n",
"\n",
"    in_window = (\n",
"        (nearby[\"X\"] >= lo_x)\n",
"        & (nearby[\"X\"] < hi_x)\n",
"        & (nearby[\"Y\"] >= lo_y)\n",
"        & (nearby[\"Y\"] < hi_y)\n",
"    )\n",
"    return nearby[in_window]\n",
"\n",
"\n",
"def intensity(vid, ti, xi, yi):\n",
"    \"\"\"Mean of the non-zero pixels of frame ti of vid inside a radius-9 disk.\n",
"\n",
"    vid is indexed vid[t, row, col] and the disk is centred at (row=yi, col=xi).\n",
"    The frame is zero-padded by 20 px on a 552 x 552 canvas, so the disk may\n",
"    extend past the frame; zero pixels (including padding) are excluded from\n",
"    the mean. Returns NaN when no pixel in the disk is non-zero.\n",
"    \"\"\"\n",
"    from skimage.draw import disk  # explicit, rather than relying on sm.draw being loaded\n",
"\n",
"    _, X, Y = vid.shape\n",
"    # Only frame ti is sampled, so pad just that frame rather than the whole video.\n",
"    frame = np.zeros([552, 552])\n",
"    frame[20:20 + X, 20:20 + Y] = vid[ti]\n",
"\n",
"    # shape= clips the disk to the canvas instead of raising IndexError near the far edges.\n",
"    rr, cc = disk([yi + 20, xi + 20], 9, shape=frame.shape)\n",
"    div = frame[rr, cc]\n",
"    div = div[div > 0]\n",
"\n",
"    # np.mean of an empty array is also NaN, but emits a RuntimeWarning.\n",
"    return np.mean(div) if div.size else np.nan\n",
" \n",
" \n",
"def main():\n",
"    \"\"\"Detect cell divisions in every video under ./dat_pred and pickle them.\n",
"\n",
"    For each entry <name> of ./dat_pred (sorted; .DS_Store and .ipynb_checkpoints\n",
"    ignored) whose dat_output/<name>/dfDivisions<name>.pkl does not exist yet:\n",
"\n",
"    1. Run the global fastai `learn` on dat_pred/<name>/<name>_<t>.tif for\n",
"       t = 0 .. T-1, where T is the frame count of focus<name>.tif minus 4, and\n",
"       save element [2][1] of each prediction (presumably the class-1\n",
"       probability map) scaled to uint8 as dat_output/<name>/pred_<name>_<t>.tif.\n",
"    2. Zero-pad each saved map by 20 px and detect blobs with blob_log.\n",
"    3. Score every blob with intensity(); for each pair that sortDL() reports\n",
"       as neighbours (1-2 frames apart, within about 13 px), drop the dimmer\n",
"       detection (on a tie, the neighbour).\n",
"    4. Pickle the surviving rows (Label, T, X, Y, Intensity) to\n",
"       dat_output/<name>/dfDivisions<name>.pkl. T is 1-based and\n",
"       Y = 532 - (row in the padded map).\n",
"    \"\"\"\n",
"    filenames = sorted(\n",
"        name\n",
"        for name in os.listdir(os.getcwd() + \"/dat_pred\")\n",
"        if name not in (\".DS_Store\", \".ipynb_checkpoints\")\n",
"    )\n",
"\n",
"    label = 0  # running id, unique across all videos handled by this call\n",
"    for filename in filenames:\n",
"        print(filename)\n",
"        out_dir = f\"dat_output/{filename}\"\n",
"        path_to_file = f\"{out_dir}/dfDivisions{filename}.pkl\"\n",
"        if exists(path_to_file):\n",
"            continue  # final output already present: skip this video\n",
"\n",
"        focus = sm.io.imread(f\"dat_pred/{filename}/focus{filename}.tif\").astype(int)\n",
"        T = focus.shape[0] - 4\n",
"\n",
"        # 1. Network prediction, one TIFF per frame.\n",
"        createFolder(out_dir)\n",
"        for t in range(T):\n",
"            out = learn.predict(f'dat_pred/{filename}/{filename}_{t}.tif')[2][1]\n",
"            out = np.asarray(out*255, \"uint8\")\n",
"            tifffile.imwrite(f\"{out_dir}/pred_{filename}_{t}.tif\", out)\n",
"\n",
"        # 2. Blob detection on the zero-padded prediction maps.\n",
"        vid = np.zeros([T, 512, 512])\n",
"        blob_parts = []\n",
"        for t in range(T):\n",
"            vid[t] = sm.io.imread(f\"{out_dir}/pred_{filename}_{t}.tif\").astype(int)\n",
"            img = np.zeros([552, 552])\n",
"            img[20:532, 20:532] = vid[t]\n",
"            blobs = blob_log(img, min_sigma=10, max_sigma=25, num_sigma=25, threshold=30)\n",
"            # Append the frame index so every row reads (row, col, sigma, t).\n",
"            frame_col = np.full([len(blobs), 1], t, dtype=float)\n",
"            blob_parts.append(np.concatenate((blobs, frame_col), axis=1))\n",
"        # np.concatenate rejects an empty list (T == 0), hence the fallback.\n",
"        blobs_logs = np.concatenate(blob_parts) if blob_parts else np.empty([0, 4])\n",
"\n",
"        # 3a. One record per blob.\n",
"        records = []\n",
"        for y, x, _sigma, t in blobs_logs:\n",
"            mu = intensity(vid, int(t), int(x - 20), int(y - 20))\n",
"            records.append(\n",
"                {\n",
"                    \"Label\": label,\n",
"                    \"T\": int(t + 1),\n",
"                    \"X\": int(x - 20),\n",
"                    \"Y\": 532 - int(y),  # map coords without boundary\n",
"                    \"Intensity\": mu,\n",
"                }\n",
"            )\n",
"            label += 1\n",
"        # Explicit columns keep the frame well-formed even when no blob was found.\n",
"        df = pd.DataFrame(records, columns=[\"Label\", \"T\", \"X\", \"Y\", \"Intensity\"])\n",
"\n",
"        # 3b. Each pairwise decision depends only on df, so collect the losing\n",
"        # labels first and drop them once (no temporary pickle needed).\n",
"        to_remove = set()\n",
"        rows = zip(df[\"T\"], df[\"X\"], df[\"Y\"], df[\"Label\"], df[\"Intensity\"])\n",
"        for ti, xi, yi, labeli, mui in rows:\n",
"            neighbours = sortDL(df, ti, xi, yi).drop_duplicates(subset=[\"T\", \"X\", \"Y\"])\n",
"            for labelj, muj in zip(neighbours[\"Label\"], neighbours[\"Intensity\"]):\n",
"                to_remove.add(labeli if mui < muj else labelj)\n",
"\n",
"        # 4. Save the surviving detections.\n",
"        dfDivisions = df[~df[\"Label\"].isin(to_remove)]\n",
"        dfDivisions = dfDivisions.drop_duplicates(subset=[\"T\", \"X\", \"Y\"])\n",
"        dfDivisions.to_pickle(path_to_file)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Unwound18h19\n"
]
},
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"\n",
" <div>\n",
" <style>\n",
" /* Turns off some styling */\n",
" progress {\n",
" /* gets rid of default border in Firefox and Opera. */\n",
" border: none;\n",
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
" background-size: auto;\n",
" }\n",
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
" background: #F44336;\n",
" }\n",
" </style>\n",
" <progress value='0' class='' max='1' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" 0.00% [0/1 00:00<00:00]\n",
" </div>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-7-263240bbee7e>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mmain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m<ipython-input-6-7ba21be3c4c1>\u001b[0m in \u001b[0;36mmain\u001b[0;34m()\u001b[0m\n\u001b[1;32m 104\u001b[0m \u001b[0mcreateFolder\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"dat_output/{filename}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 105\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mt\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mT\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 106\u001b[0;31m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlearn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpredict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf'dat_pred/{filename}/{filename}_{t}.tif'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 107\u001b[0m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0masarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0;36m255\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"uint8\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 108\u001b[0m \u001b[0mtifffile\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mimwrite\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"dat_output/{filename}/pred_{filename}_{t}.tif\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/opt/conda/envs/fastai/lib/python3.8/site-packages/fastai/learner.py\u001b[0m in \u001b[0;36mpredict\u001b[0;34m(self, item, rm_type_tfms, with_input)\u001b[0m\n\u001b[1;32m 264\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mpredict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mitem\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrm_type_tfms\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mwith_input\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 265\u001b[0m \u001b[0mdl\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdls\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtest_dl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mitem\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrm_type_tfms\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mrm_type_tfms\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_workers\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 266\u001b[0;31m \u001b[0minp\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mpreds\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mdec_preds\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_preds\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdl\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mwith_input\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mwith_decoded\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 267\u001b[0m \u001b[0mi\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdls\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'n_inp'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 268\u001b[0m \u001b[0minp\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0minp\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;36m1\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mtuplify\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minp\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/opt/conda/envs/fastai/lib/python3.8/site-packages/fastai/learner.py\u001b[0m in \u001b[0;36mget_preds\u001b[0;34m(self, ds_idx, dl, with_input, with_decoded, with_loss, act, inner, reorder, cbs, **kwargs)\u001b[0m\n\u001b[1;32m 251\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mwith_loss\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mctx_mgrs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloss_not_reduced\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 252\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mContextManagers\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mctx_mgrs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 253\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_do_epoch_validate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdl\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdl\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 254\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mact\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mact\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloss_func\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'activation'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnoop\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 255\u001b[0m \u001b[0mres\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mall_tensors\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/opt/conda/envs/fastai/lib/python3.8/site-packages/fastai/learner.py\u001b[0m in \u001b[0;36m_do_epoch_validate\u001b[0;34m(self, ds_idx, dl)\u001b[0m\n\u001b[1;32m 201\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mdl\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mdl\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdls\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mds_idx\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 202\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdl\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdl\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 203\u001b[0;31m \u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mno_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_with_events\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mall_batches\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'validate'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mCancelValidException\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 204\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 205\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_do_epoch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/opt/conda/envs/fastai/lib/python3.8/site-packages/fastai/learner.py\u001b[0m in \u001b[0;36m_with_events\u001b[0;34m(self, f, event_type, ex, final)\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 162\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_with_events\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mevent_type\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mex\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfinal\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnoop\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 163\u001b[0;31m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf'before_{event_type}'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m;\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 164\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mex\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf'after_cancel_{event_type}'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 165\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf'after_{event_type}'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m;\u001b[0m \u001b[0mfinal\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/opt/conda/envs/fastai/lib/python3.8/site-packages/fastai/learner.py\u001b[0m in \u001b[0;36mall_batches\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 167\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mall_batches\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 168\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mn_iter\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdl\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 169\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mo\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdl\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mone_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mo\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 170\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 171\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_do_one_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/opt/conda/envs/fastai/lib/python3.8/site-packages/fastai/learner.py\u001b[0m in \u001b[0;36mone_batch\u001b[0;34m(self, i, b)\u001b[0m\n\u001b[1;32m 192\u001b[0m \u001b[0mb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_set_device\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 193\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_split\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 194\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_with_events\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_do_one_batch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'batch'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mCancelBatchException\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 195\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 196\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_do_epoch_train\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/opt/conda/envs/fastai/lib/python3.8/site-packages/fastai/learner.py\u001b[0m in \u001b[0;36m_with_events\u001b[0;34m(self, f, event_type, ex, final)\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 162\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_with_events\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mevent_type\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mex\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfinal\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnoop\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 163\u001b[0;31m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf'before_{event_type}'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m;\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 164\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mex\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf'after_cancel_{event_type}'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 165\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf'after_{event_type}'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m;\u001b[0m \u001b[0mfinal\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/opt/conda/envs/fastai/lib/python3.8/site-packages/fastai/learner.py\u001b[0m in \u001b[0;36m_do_one_batch\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 170\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 171\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_do_one_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 172\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpred\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mxb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 173\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'after_pred'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 174\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0myb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/opt/conda/envs/fastai/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 725\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 726\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 728\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 729\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/opt/conda/envs/fastai/lib/python3.8/site-packages/fastai/layers.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 405\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0ml\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlayers\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 406\u001b[0m \u001b[0mres\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0morig\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 407\u001b[0;31m \u001b[0mnres\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0ml\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mres\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 408\u001b[0m \u001b[0;31m# We have to remove res.orig to avoid hanging refs and therefore memory leaks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 409\u001b[0m \u001b[0mres\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0morig\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnres\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0morig\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/opt/conda/envs/fastai/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 725\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 726\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 728\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 729\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/opt/conda/envs/fastai/lib/python3.8/site-packages/fastai/vision/models/unet.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, up_in)\u001b[0m\n\u001b[1;32m 33\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mup_in\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 34\u001b[0m \u001b[0ms\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhook\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstored\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 35\u001b[0;31m \u001b[0mup_out\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshuf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mup_in\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 36\u001b[0m \u001b[0mssh\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0ms\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 37\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mssh\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0mup_out\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/opt/conda/envs/fastai/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 725\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 726\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 728\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 729\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/opt/conda/envs/fastai/lib/python3.8/site-packages/torch/nn/modules/container.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 115\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 116\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 117\u001b[0;31m \u001b[0minput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodule\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 118\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 119\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/opt/conda/envs/fastai/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 725\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 726\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 728\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 729\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/opt/conda/envs/fastai/lib/python3.8/site-packages/torch/nn/modules/container.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 115\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 116\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 117\u001b[0;31m \u001b[0minput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodule\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 118\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 119\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/opt/conda/envs/fastai/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 725\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 726\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 728\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 729\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/opt/conda/envs/fastai/lib/python3.8/site-packages/torch/nn/modules/conv.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 421\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 422\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 423\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_conv_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 424\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 425\u001b[0m \u001b[0;32mclass\u001b[0m \u001b[0mConv3d\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_ConvNd\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/opt/conda/envs/fastai/lib/python3.8/site-packages/torch/nn/modules/conv.py\u001b[0m in \u001b[0;36m_conv_forward\u001b[0;34m(self, input, weight)\u001b[0m\n\u001b[1;32m 417\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstride\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 418\u001b[0m _pair(0), self.dilation, self.groups)\n\u001b[0;32m--> 419\u001b[0;31m return F.conv2d(input, weight, self.bias, self.stride,\n\u001b[0m\u001b[1;32m 420\u001b[0m self.padding, self.dilation, self.groups)\n\u001b[1;32m 421\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"main()"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------------------------------------\n",
" Layer (type) Output Shape Param #\n",
"================================================================\n",
" Conv2d-1 [-1, 64, 256, 256] 9,408\n",
" BatchNorm2d-2 [-1, 64, 256, 256] 128\n",
" ReLU-3 [-1, 64, 256, 256] 0\n",
" MaxPool2d-4 [-1, 64, 128, 128] 0\n",
" Conv2d-5 [-1, 64, 128, 128] 36,864\n",
" BatchNorm2d-6 [-1, 64, 128, 128] 128\n",
" ReLU-7 [-1, 64, 128, 128] 0\n",
" Conv2d-8 [-1, 64, 128, 128] 36,864\n",
" BatchNorm2d-9 [-1, 64, 128, 128] 128\n",
" ReLU-10 [-1, 64, 128, 128] 0\n",
" BasicBlock-11 [-1, 64, 128, 128] 0\n",
" Conv2d-12 [-1, 64, 128, 128] 36,864\n",
" BatchNorm2d-13 [-1, 64, 128, 128] 128\n",
" ReLU-14 [-1, 64, 128, 128] 0\n",
" Conv2d-15 [-1, 64, 128, 128] 36,864\n",
" BatchNorm2d-16 [-1, 64, 128, 128] 128\n",
" ReLU-17 [-1, 64, 128, 128] 0\n",
" BasicBlock-18 [-1, 64, 128, 128] 0\n",
" Conv2d-19 [-1, 64, 128, 128] 36,864\n",
" BatchNorm2d-20 [-1, 64, 128, 128] 128\n",
" ReLU-21 [-1, 64, 128, 128] 0\n",
" Conv2d-22 [-1, 64, 128, 128] 36,864\n",
" BatchNorm2d-23 [-1, 64, 128, 128] 128\n",
" ReLU-24 [-1, 64, 128, 128] 0\n",
" BasicBlock-25 [-1, 64, 128, 128] 0\n",
" Conv2d-26 [-1, 128, 64, 64] 73,728\n",
" BatchNorm2d-27 [-1, 128, 64, 64] 256\n",
" ReLU-28 [-1, 128, 64, 64] 0\n",
" Conv2d-29 [-1, 128, 64, 64] 147,456\n",
" BatchNorm2d-30 [-1, 128, 64, 64] 256\n",
" Conv2d-31 [-1, 128, 64, 64] 8,192\n",
" BatchNorm2d-32 [-1, 128, 64, 64] 256\n",
" ReLU-33 [-1, 128, 64, 64] 0\n",
" BasicBlock-34 [-1, 128, 64, 64] 0\n",
" Conv2d-35 [-1, 128, 64, 64] 147,456\n",
" BatchNorm2d-36 [-1, 128, 64, 64] 256\n",
" ReLU-37 [-1, 128, 64, 64] 0\n",
" Conv2d-38 [-1, 128, 64, 64] 147,456\n",
" BatchNorm2d-39 [-1, 128, 64, 64] 256\n",
" ReLU-40 [-1, 128, 64, 64] 0\n",
" BasicBlock-41 [-1, 128, 64, 64] 0\n",
" Conv2d-42 [-1, 128, 64, 64] 147,456\n",
" BatchNorm2d-43 [-1, 128, 64, 64] 256\n",
" ReLU-44 [-1, 128, 64, 64] 0\n",
" Conv2d-45 [-1, 128, 64, 64] 147,456\n",
" BatchNorm2d-46 [-1, 128, 64, 64] 256\n",
" ReLU-47 [-1, 128, 64, 64] 0\n",
" BasicBlock-48 [-1, 128, 64, 64] 0\n",
" Conv2d-49 [-1, 128, 64, 64] 147,456\n",
" BatchNorm2d-50 [-1, 128, 64, 64] 256\n",
" ReLU-51 [-1, 128, 64, 64] 0\n",
" Conv2d-52 [-1, 128, 64, 64] 147,456\n",
" BatchNorm2d-53 [-1, 128, 64, 64] 256\n",
" ReLU-54 [-1, 128, 64, 64] 0\n",
" BasicBlock-55 [-1, 128, 64, 64] 0\n",
" Conv2d-56 [-1, 256, 32, 32] 294,912\n",
" BatchNorm2d-57 [-1, 256, 32, 32] 512\n",
" ReLU-58 [-1, 256, 32, 32] 0\n",
" Conv2d-59 [-1, 256, 32, 32] 589,824\n",
" BatchNorm2d-60 [-1, 256, 32, 32] 512\n",
" Conv2d-61 [-1, 256, 32, 32] 32,768\n",
" BatchNorm2d-62 [-1, 256, 32, 32] 512\n",
" ReLU-63 [-1, 256, 32, 32] 0\n",
" BasicBlock-64 [-1, 256, 32, 32] 0\n",
" Conv2d-65 [-1, 256, 32, 32] 589,824\n",
" BatchNorm2d-66 [-1, 256, 32, 32] 512\n",
" ReLU-67 [-1, 256, 32, 32] 0\n",
" Conv2d-68 [-1, 256, 32, 32] 589,824\n",
" BatchNorm2d-69 [-1, 256, 32, 32] 512\n",
" ReLU-70 [-1, 256, 32, 32] 0\n",
" BasicBlock-71 [-1, 256, 32, 32] 0\n",
" Conv2d-72 [-1, 256, 32, 32] 589,824\n",
" BatchNorm2d-73 [-1, 256, 32, 32] 512\n",
" ReLU-74 [-1, 256, 32, 32] 0\n",
" Conv2d-75 [-1, 256, 32, 32] 589,824\n",
" BatchNorm2d-76 [-1, 256, 32, 32] 512\n",
" ReLU-77 [-1, 256, 32, 32] 0\n",
" BasicBlock-78 [-1, 256, 32, 32] 0\n",
" Conv2d-79 [-1, 256, 32, 32] 589,824\n",
" BatchNorm2d-80 [-1, 256, 32, 32] 512\n",
" ReLU-81 [-1, 256, 32, 32] 0\n",
" Conv2d-82 [-1, 256, 32, 32] 589,824\n",
" BatchNorm2d-83 [-1, 256, 32, 32] 512\n",
" ReLU-84 [-1, 256, 32, 32] 0\n",
" BasicBlock-85 [-1, 256, 32, 32] 0\n",
" Conv2d-86 [-1, 256, 32, 32] 589,824\n",
" BatchNorm2d-87 [-1, 256, 32, 32] 512\n",
" ReLU-88 [-1, 256, 32, 32] 0\n",
" Conv2d-89 [-1, 256, 32, 32] 589,824\n",
" BatchNorm2d-90 [-1, 256, 32, 32] 512\n",
" ReLU-91 [-1, 256, 32, 32] 0\n",
" BasicBlock-92 [-1, 256, 32, 32] 0\n",
" Conv2d-93 [-1, 256, 32, 32] 589,824\n",
" BatchNorm2d-94 [-1, 256, 32, 32] 512\n",
" ReLU-95 [-1, 256, 32, 32] 0\n",
" Conv2d-96 [-1, 256, 32, 32] 589,824\n",
" BatchNorm2d-97 [-1, 256, 32, 32] 512\n",
" ReLU-98 [-1, 256, 32, 32] 0\n",
" BasicBlock-99 [-1, 256, 32, 32] 0\n",
" Conv2d-100 [-1, 512, 16, 16] 1,179,648\n",
" BatchNorm2d-101 [-1, 512, 16, 16] 1,024\n",
" ReLU-102 [-1, 512, 16, 16] 0\n",
" Conv2d-103 [-1, 512, 16, 16] 2,359,296\n",
" BatchNorm2d-104 [-1, 512, 16, 16] 1,024\n",
" Conv2d-105 [-1, 512, 16, 16] 131,072\n",
" BatchNorm2d-106 [-1, 512, 16, 16] 1,024\n",
" ReLU-107 [-1, 512, 16, 16] 0\n",
" BasicBlock-108 [-1, 512, 16, 16] 0\n",
" Conv2d-109 [-1, 512, 16, 16] 2,359,296\n",
" BatchNorm2d-110 [-1, 512, 16, 16] 1,024\n",
" ReLU-111 [-1, 512, 16, 16] 0\n",
" Conv2d-112 [-1, 512, 16, 16] 2,359,296\n",
" BatchNorm2d-113 [-1, 512, 16, 16] 1,024\n",
" ReLU-114 [-1, 512, 16, 16] 0\n",
" BasicBlock-115 [-1, 512, 16, 16] 0\n",
" Conv2d-116 [-1, 512, 16, 16] 2,359,296\n",
" BatchNorm2d-117 [-1, 512, 16, 16] 1,024\n",
" ReLU-118 [-1, 512, 16, 16] 0\n",
" Conv2d-119 [-1, 512, 16, 16] 2,359,296\n",
" BatchNorm2d-120 [-1, 512, 16, 16] 1,024\n",
" ReLU-121 [-1, 512, 16, 16] 0\n",
" BasicBlock-122 [-1, 512, 16, 16] 0\n",
" BatchNorm2d-123 [-1, 512, 16, 16] 1,024\n",
" ReLU-124 [-1, 512, 16, 16] 0\n",
" Conv2d-125 [-1, 1024, 16, 16] 4,719,616\n",
" Mish-126 [-1, 1024, 16, 16] 0\n",
" Conv2d-127 [-1, 512, 16, 16] 4,719,104\n",
" Mish-128 [-1, 512, 16, 16] 0\n",
" Conv2d-129 [-1, 1024, 16, 16] 525,312\n",
" Mish-130 [-1, 1024, 16, 16] 0\n",
" PixelShuffle-131 [-1, 256, 32, 32] 0\n",
" BatchNorm2d-132 [-1, 256, 32, 32] 512\n",
" Mish-133 [-1, 512, 32, 32] 0\n",
" Conv2d-134 [-1, 512, 32, 32] 2,359,808\n",
" Mish-135 [-1, 512, 32, 32] 0\n",
" Conv2d-136 [-1, 512, 32, 32] 2,359,808\n",
" Mish-137 [-1, 512, 32, 32] 0\n",
" UnetBlock-138 [-1, 512, 32, 32] 0\n",
" Conv2d-139 [-1, 1024, 32, 32] 525,312\n",
" Mish-140 [-1, 1024, 32, 32] 0\n",
" PixelShuffle-141 [-1, 256, 64, 64] 0\n",
" BatchNorm2d-142 [-1, 128, 64, 64] 256\n",
" Mish-143 [-1, 384, 64, 64] 0\n",
" Conv2d-144 [-1, 384, 64, 64] 1,327,488\n",
" Mish-145 [-1, 384, 64, 64] 0\n",
" Conv2d-146 [-1, 384, 64, 64] 1,327,488\n",
" Mish-147 [-1, 384, 64, 64] 0\n",
" Conv1d-148 [-1, 48, 4096] 18,432\n",
" Conv1d-149 [-1, 48, 4096] 18,432\n",
" Conv1d-150 [-1, 384, 4096] 147,456\n",
" SelfAttention-151 [-1, 384, 64, 64] 0\n",
" UnetBlock-152 [-1, 384, 64, 64] 0\n",
" Conv2d-153 [-1, 768, 64, 64] 295,680\n",
" Mish-154 [-1, 768, 64, 64] 0\n",
" PixelShuffle-155 [-1, 192, 128, 128] 0\n",
" BatchNorm2d-156 [-1, 64, 128, 128] 128\n",
" Mish-157 [-1, 256, 128, 128] 0\n",
" Conv2d-158 [-1, 256, 128, 128] 590,080\n",
" Mish-159 [-1, 256, 128, 128] 0\n",
" Conv2d-160 [-1, 256, 128, 128] 590,080\n",
" Mish-161 [-1, 256, 128, 128] 0\n",
" UnetBlock-162 [-1, 256, 128, 128] 0\n",
" Conv2d-163 [-1, 512, 128, 128] 131,584\n",
" Mish-164 [-1, 512, 128, 128] 0\n",
" PixelShuffle-165 [-1, 128, 256, 256] 0\n",
" BatchNorm2d-166 [-1, 64, 256, 256] 128\n",
" Mish-167 [-1, 192, 256, 256] 0\n",
" Conv2d-168 [-1, 96, 256, 256] 165,984\n",
" Mish-169 [-1, 96, 256, 256] 0\n",
" Conv2d-170 [-1, 96, 256, 256] 83,040\n",
" Mish-171 [-1, 96, 256, 256] 0\n",
" UnetBlock-172 [-1, 96, 256, 256] 0\n",
" Conv2d-173 [-1, 384, 256, 256] 37,248\n",
" Mish-174 [-1, 384, 256, 256] 0\n",
" PixelShuffle-175 [-1, 96, 512, 512] 0\n",
" ResizeToOrig-176 [-1, 96, 512, 512] 0\n",
" MergeLayer-177 [-1, 99, 512, 512] 0\n",
" Conv2d-178 [-1, 99, 512, 512] 88,308\n",
" Mish-179 [-1, 99, 512, 512] 0\n",
" Conv2d-180 [-1, 99, 512, 512] 88,308\n",
" Mish-181 [-1, 99, 512, 512] 0\n",
" ResBlock-182 [-1, 99, 512, 512] 0\n",
" Conv2d-183 [-1, 3, 512, 512] 300\n",
" ToTensorBase-184 [-1, 3, 512, 512] 0\n",
"================================================================\n",
"Total params: 41,405,588\n",
"Trainable params: 20,137,940\n",
"Non-trainable params: 21,267,648\n",
"----------------------------------------------------------------\n",
"Input size (MB): 3.00\n",
"Forward/backward pass size (MB): 3470.00\n",
"Params size (MB): 157.95\n",
"Estimated Total Size (MB): 3630.95\n",
"----------------------------------------------------------------\n"
]
}
],
"source": [
"model = learn.model\n",
"from torchsummary import summary\n",
"summary(model, (3, 512, 512))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
