https://github.com/turleyjm/cell-division-dl-plugin
Tip revision: f6c3290670f23524d47ccddc873ef5673eda4b3c authored by turleyjm on 11 July 2024, 02:09:47 UTC
Update README.md
Update README.md
Tip revision: f6c3290
trainingUNetCellDivision10.ipynb
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"ename": "AttributeError",
"evalue": "module 'wandb.proto.wandb_internal_pb2' has no attribute 'Result'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-1-2557cbe562cb>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mmatplotlib\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mpyplot\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mplt\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtifffile\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 18\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mtimm\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 19\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mfastai\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvision\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mall\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/opt/conda/envs/fastai/lib/python3.8/site-packages/timm/__init__.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mversion\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0m__version__\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mlayers\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mis_scriptable\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mis_exportable\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mset_scriptable\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mset_exportable\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mmodels\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mcreate_model\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlist_models\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlist_pretrained\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mis_model\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlist_modules\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel_entrypoint\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m\\\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0mis_model_pretrained\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mget_pretrained_cfg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mget_pretrained_cfg_value\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/opt/conda/envs/fastai/lib/python3.8/site-packages/timm/models/__init__.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mbeit\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mbyoanet\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mbyobnet\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mcait\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mcoat\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/opt/conda/envs/fastai/lib/python3.8/site-packages/timm/models/beit.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 47\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcheckpoint\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mcheckpoint\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 49\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mtimm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mIMAGENET_DEFAULT_MEAN\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mIMAGENET_DEFAULT_STD\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 50\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtimm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlayers\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mPatchEmbed\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mMlp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mSwiGLU\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mLayerNorm\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mDropPath\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrunc_normal_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0muse_fused_attn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 51\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtimm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlayers\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mresample_patch_embed\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresample_abs_pos_embed\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresize_rel_pos_bias_table\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/opt/conda/envs/fastai/lib/python3.8/site-packages/timm/data/__init__.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mconfig\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mresolve_data_config\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresolve_model_data_config\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mconstants\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mImageDataset\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mIterableImageDataset\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mAugMixDataset\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mdataset_factory\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mcreate_dataset\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mdataset_info\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mDatasetInfo\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mCustomDatasetInfo\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/opt/conda/envs/fastai/lib/python3.8/site-packages/timm/data/dataset.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mPIL\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mImage\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 13\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mreaders\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mcreate_reader\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 14\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0m_logger\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlogging\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgetLogger\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m__name__\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/opt/conda/envs/fastai/lib/python3.8/site-packages/timm/data/readers/__init__.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mreader_factory\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mcreate_reader\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mimg_extensions\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/opt/conda/envs/fastai/lib/python3.8/site-packages/timm/data/readers/reader_factory.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mreader_image_folder\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mReaderImageFolder\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mreader_image_in_tar\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mReaderImageInTar\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/opt/conda/envs/fastai/lib/python3.8/site-packages/timm/data/readers/reader_image_folder.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtyping\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mDict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mList\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mOptional\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mSet\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mTuple\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 11\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mtimm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmisc\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mnatural_key\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 12\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mclass_map\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mload_class_map\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/opt/conda/envs/fastai/lib/python3.8/site-packages/timm/utils/__init__.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mmodel_ema\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mModelEma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mModelEmaV2\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mrandom\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mrandom_seed\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 15\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0msummary\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mupdate_summary\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mget_outdir\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m/opt/conda/envs/fastai/lib/python3.8/site-packages/timm/utils/summary.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mcollections\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mOrderedDict\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 9\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mwandb\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 10\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mImportError\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0;32mpass\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/opt/conda/envs/fastai/lib/python3.8/site-packages/wandb/__init__.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mwandb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0merrors\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mterm\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtermsetup\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtermlog\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtermerror\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtermwarn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 27\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mwandb\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0msdk\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mwandb_sdk\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 28\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 29\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mwandb\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/opt/conda/envs/fastai/lib/python3.8/site-packages/wandb/sdk/__init__.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mwandb_helper\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mhelper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 25\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0martifacts\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0martifact\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mArtifact\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 26\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mwandb_alerts\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mAlertLevel\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mwandb_config\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mConfig\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/opt/conda/envs/fastai/lib/python3.8/site-packages/wandb/sdk/artifacts/artifact.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 35\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mwandb\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 36\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mwandb\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mdata_types\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0menv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mutil\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 37\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mwandb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapis\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnormalize\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mnormalize_exceptions\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 38\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mwandb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapis\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpublic\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mArtifactCollection\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mArtifactFiles\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mRetryingClient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mRun\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 39\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mwandb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata_types\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mWBValue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/opt/conda/envs/fastai/lib/python3.8/site-packages/wandb/apis/__init__.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 41\u001b[0m \u001b[0mreset_path\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mutil\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvendor_setup\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 42\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 43\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0minternal\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mApi\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mInternalApi\u001b[0m \u001b[0;31m# noqa\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 44\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mpublic\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mApi\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mPublicApi\u001b[0m \u001b[0;31m# noqa\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 45\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/opt/conda/envs/fastai/lib/python3.8/site-packages/wandb/apis/internal.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtyping\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mAny\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mwandb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msdk\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minternal\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minternal_api\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mApi\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mInternalApi\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/opt/conda/envs/fastai/lib/python3.8/site-packages/wandb/sdk/internal/internal_api.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 46\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mwandb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msdk\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhashutil\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mB64MD5\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmd5_file_b64\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 47\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 48\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlib\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mretry\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 49\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfilenames\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mDIFF_FNAME\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mMETADATA_FNAME\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgitlib\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mGitRepo\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/opt/conda/envs/fastai/lib/python3.8/site-packages/wandb/sdk/lib/retry.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mwandb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutil\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mCheckRetryFnType\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 17\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mmailbox\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mContextCancelledError\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 18\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[0mlogger\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlogging\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgetLogger\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m__name__\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/opt/conda/envs/fastai/lib/python3.8/site-packages/wandb/sdk/lib/mailbox.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 100\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 101\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 102\u001b[0;31m \u001b[0;32mclass\u001b[0m \u001b[0m_MailboxSlot\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 103\u001b[0m \u001b[0m_result\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptional\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mpb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mResult\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 104\u001b[0m \u001b[0m_event\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mthreading\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mEvent\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/opt/conda/envs/fastai/lib/python3.8/site-packages/wandb/sdk/lib/mailbox.py\u001b[0m in \u001b[0;36m_MailboxSlot\u001b[0;34m()\u001b[0m\n\u001b[1;32m 101\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 102\u001b[0m \u001b[0;32mclass\u001b[0m \u001b[0m_MailboxSlot\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 103\u001b[0;31m \u001b[0m_result\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptional\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mpb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mResult\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 104\u001b[0m \u001b[0m_event\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mthreading\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mEvent\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 105\u001b[0m \u001b[0m_lock\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mthreading\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mLock\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mAttributeError\u001b[0m: module 'wandb.proto.wandb_internal_pb2' has no attribute 'Result'"
]
}
],
"source": [
"from cgitb import reset\n",
"import torch\n",
"import albumentations as A\n",
"from albumentations.pytorch import ToTensorV2\n",
"from tqdm import tqdm\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"import torchvision\n",
"from torch.utils.data import DataLoader\n",
"import os\n",
"from PIL import Image\n",
"from torch.utils.data import Dataset\n",
"import numpy as np\n",
"import skimage as sm\n",
"import skimage.io\n",
"from matplotlib import pyplot as plt\n",
"import tifffile\n",
"import timm\n",
"from fastai.vision.all import *\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hyperparameters and data locations\n",
"\n",
"LEARNING_RATE = 1e-4  # Adam learning rate\n",
"DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"BATCH_SIZE = 1\n",
"NUM_EPOCHS = 4\n",
"NUM_WORKERS = 2  # DataLoader worker processes\n",
"IMAGE_HEIGHT = 512  # expected frame size in pixels\n",
"IMAGE_WIDTH = 512\n",
"# Pinned host memory only speeds up host-to-GPU copies; without CUDA PyTorch\n",
"# ignores it and warns, so only request it when training on the GPU.\n",
"PIN_MEMORY = DEVICE == \"cuda\"\n",
"LOAD_MODEL = True  # load weights from models/UNetCellDivision10.pth.tar before training\n",
"\n",
"# Each image stack X.tif has its mask X_mask.tif in the matching *_MASK_DIR\n",
"TRAIN_IMG_DIR = \"dat_division/divisionData/train_images/\"\n",
"TRAIN_MASK_DIR = \"dat_division/divisionData/train_masks/\"\n",
"VAL_IMG_DIR = \"dat_division/divisionData/val_images/\"\n",
"VAL_MASK_DIR = \"dat_division/divisionData/val_masks/\"\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# defines the dataloader\n",
"\n",
"class DivisionDataset(Dataset):\n",
"    \"\"\"Dataset pairing each multi-frame image stack with its division mask.\n",
"\n",
"    Every file in ``image_dir`` is read with ``skimage.io.imread`` and is\n",
"    assumed to be a stack of shape (frames, H, W) (10 frames in this\n",
"    notebook - TODO confirm for new data). Its mask is the file in\n",
"    ``mask_dir`` with the same name but ``.tif`` replaced by ``_mask.tif``;\n",
"    mask pixels equal to 255 are set to 1.\n",
"\n",
"    Args:\n",
"        image_dir: directory holding the image stacks.\n",
"        mask_dir: directory holding the matching masks.\n",
"        transform: optional albumentations ``Compose``. Frame 0 is passed as\n",
"            ``image`` and frame k (k >= 1) as ``image{k-1}``, so those names\n",
"            must be registered in the transform's ``additional_targets``.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, image_dir, mask_dir, transform=None):\n",
"        self.image_dir = image_dir\n",
"        self.mask_dir = mask_dir\n",
"        self.transform = transform\n",
"        self.images = os.listdir(image_dir)\n",
"\n",
"    def __len__(self):\n",
"        return len(self.images)\n",
"\n",
"    def __getitem__(self, index):\n",
"        \"\"\"Return ``(images, mask)`` for sample ``index``.\n",
"\n",
"        ``images`` is a float32 tensor with one channel per frame. Without a\n",
"        transform it is the raw stack divided by 255, the same scaling as the\n",
"        ``A.Normalize(mean=0, std=1, max_pixel_value=255.0)`` step used by the\n",
"        training and validation pipelines, and ``mask`` is a float32 tensor.\n",
"        With a transform both come from the transform's output.\n",
"        \"\"\"\n",
"        name = self.images[index]\n",
"        img_path = os.path.join(self.image_dir, name)\n",
"        mask_path = os.path.join(self.mask_dir, name.replace(\".tif\", \"_mask.tif\"))\n",
"        image = sm.io.imread(img_path).astype(np.float32)\n",
"        mask = np.array(Image.open(mask_path), dtype=np.float32)\n",
"        mask[mask == 255] = 1\n",
"        # Binarised, pre-augmentation mask; only read by the save_transform debug hook.\n",
"        mask0 = mask\n",
"\n",
"        if self.transform is None:\n",
"            # 255 (not 256) so this path matches the A.Normalize scaling.\n",
"            return torch.tensor(image / 255).float(), torch.from_numpy(mask)\n",
"\n",
"        # Passing every frame and the mask in one call makes albumentations\n",
"        # apply the same random augmentation to all of them.\n",
"        keys = [\"image\"] + [f\"image{k}\" for k in range(len(image) - 1)]\n",
"        transformed = self.transform(mask=mask, **dict(zip(keys, image)))\n",
"        images = torch.empty(image.shape, dtype=torch.float32)\n",
"        for frame, key in enumerate(keys):\n",
"            images[frame] = transformed[key]\n",
"        mask = transformed[\"mask\"]\n",
"\n",
"        # Uncomment to write a before/after montage and check the augmentations:\n",
"        # save_transform(image, mask0, transformed)\n",
"\n",
"        return images, mask\n",
"\n",
"# saves the before and after transform by the augmentations\n",
"def save_transform(image, mask0, transformed, path=\"transformResults/transform.tif\"):\n",
"    \"\"\"Write a before/after augmentation montage to a multi-page TIFF.\n",
"\n",
"    Page k is a 2x2 grid with a 10-pixel gap: raw frame k (top-left),\n",
"    augmented frame k (top-right), the mask before augmentation\n",
"    (bottom-left) and the augmented mask (bottom-right).\n",
"\n",
"    Args:\n",
"        image: raw stack, assumed (frames, H, W) with values in 0-255.\n",
"        mask0: binarised mask (0/1) before augmentation.\n",
"        transformed: albumentations output with \"image\", \"image0\", ... and\n",
"            \"mask\"; these entries are multiplied by 255 for display.\n",
"        path: output file; its parent directory is created if missing.\n",
"    \"\"\"\n",
"    gap = 10\n",
"    frames, height, width = image.shape\n",
"    keys = [\"image\"] + [f\"image{k}\" for k in range(frames - 1)]\n",
"\n",
"    result = np.zeros([frames, 2 * height + gap, 2 * width + gap])\n",
"    result[:, :height, :width] = image\n",
"    for frame, key in enumerate(keys):\n",
"        result[frame, :height, width + gap:] = np.array(transformed[key]) * 255\n",
"    result[:, height + gap:, :width] = mask0 * 255\n",
"    result[:, height + gap:, width + gap:] = np.array(transformed[\"mask\"]) * 255\n",
"\n",
"    # Clip before the cast: out-of-range values would otherwise wrap around in uint8.\n",
"    result = np.clip(result, 0, 255).astype(np.uint8)\n",
"    os.makedirs(os.path.dirname(path) or \".\", exist_ok=True)\n",
"    tifffile.imwrite(path, result)\n",
"\n",
"\n",
"# util\n",
"\n",
"# save model parameters\n",
"def save_checkpoint(state, filename=\"models/UNetCellDivision10.pth.tar\"):\n",
"    \"\"\"Serialise ``state`` to ``filename`` with ``torch.save``.\n",
"\n",
"    The parent directory is created first so the save does not fail on a\n",
"    fresh checkout that has no ``models/`` folder yet.\n",
"    \"\"\"\n",
"    print(\"=> Saving checkpoint\")\n",
"    directory = os.path.dirname(filename)\n",
"    if directory:\n",
"        os.makedirs(directory, exist_ok=True)\n",
"    torch.save(state, filename)\n",
"\n",
"# load model parameters\n",
"def load_checkpoint(checkpoint, model, optimizer=None):\n",
"    \"\"\"Restore model weights, and optionally optimizer state, from a checkpoint dict.\n",
"\n",
"    Args:\n",
"        checkpoint: mapping with a \"state_dict\" entry (``main`` also stores an\n",
"            \"optimizer\" entry).\n",
"        model: module that receives ``checkpoint[\"state_dict\"]``.\n",
"        optimizer: if given and the checkpoint has an \"optimizer\" entry, that\n",
"            state is restored too so a resumed run keeps its optimizer statistics.\n",
"    \"\"\"\n",
"    print(\"=> Loading checkpoint\")\n",
"    model.load_state_dict(checkpoint[\"state_dict\"])\n",
"    if optimizer is not None and \"optimizer\" in checkpoint:\n",
"        optimizer.load_state_dict(checkpoint[\"optimizer\"])\n",
"\n",
"# Make the dataloader \n",
"def get_loaders(\n",
"    train_dir,\n",
"    train_maskdir,\n",
"    val_dir,\n",
"    val_maskdir,\n",
"    batch_size,\n",
"    train_transform,\n",
"    val_transform,\n",
"    num_workers=4,\n",
"    pin_memory=True\n",
"):\n",
"    \"\"\"Build the training and validation DataLoaders over DivisionDataset.\n",
"\n",
"    Both loaders share ``batch_size``, ``num_workers`` and ``pin_memory``;\n",
"    only the training loader shuffles.\n",
"\n",
"    Returns:\n",
"        ``(train_loader, val_loader)``\n",
"    \"\"\"\n",
"    def make_loader(image_dir, mask_dir, transform, shuffle):\n",
"        # One dataset + loader pair; only the split-specific arguments differ.\n",
"        dataset = DivisionDataset(image_dir=image_dir, mask_dir=mask_dir, transform=transform)\n",
"        return DataLoader(\n",
"            dataset,\n",
"            batch_size=batch_size,\n",
"            num_workers=num_workers,\n",
"            pin_memory=pin_memory,\n",
"            shuffle=shuffle,\n",
"        )\n",
"\n",
"    train_loader = make_loader(train_dir, train_maskdir, train_transform, shuffle=True)\n",
"    val_loader = make_loader(val_dir, val_maskdir, val_transform, shuffle=False)\n",
"    return train_loader, val_loader\n",
"\n",
"# define metric to assess model performance \n",
"def check_accuracy(loader, model, device=\"cuda\"):\n",
"    \"\"\"Print pixel accuracy and mean per-batch Dice score of ``model`` on ``loader``.\n",
"\n",
"    Predictions are ``sigmoid(logits) > 0.5``. Dice is computed per batch and\n",
"    averaged over batches. The model runs in eval mode and is returned to\n",
"    whichever mode it was in before the call.\n",
"    \"\"\"\n",
"    was_training = model.training\n",
"    num_correct = 0\n",
"    num_pixels = 0\n",
"    dice_score = 0.0\n",
"    model.eval()\n",
"\n",
"    with torch.no_grad():\n",
"        for x, y in tqdm(loader):\n",
"            x = x.to(device)\n",
"            y = y.to(device).unsqueeze(1)\n",
"            preds = (torch.sigmoid(model(x)) > 0.5).float()\n",
"            # .item() keeps the totals as Python numbers, so the report prints\n",
"            # plain values instead of tensor(..., device='cuda:0').\n",
"            num_correct += (preds == y).sum().item()\n",
"            num_pixels += preds.numel()\n",
"            dice_score += ((2 * (preds * y).sum()) / ((preds + y).sum() + 1e-8)).item()\n",
"\n",
"    model.train(was_training)\n",
"\n",
"    if num_pixels == 0:\n",
"        print(\"No batches in loader: accuracy and Dice score are undefined\")\n",
"        return\n",
"    print(f\"Accuracy {num_correct / num_pixels * 100:.2f}%\")\n",
"    print(f\"Dice score {dice_score / len(loader):.4f}\")\n",
"\n",
"# saves the ground truth with the model prediciton to folder saved_images\n",
"def save_predictions_as_imgs(loader, model, folder=\"saved_images/\", device=\"cuda\"):\n",
"    \"\"\"Save predictions and ground-truth masks for the first batch of ``loader``.\n",
"\n",
"    For each sample i of that batch, writes ``pred_{i}.png`` (``sigmoid > 0.5``)\n",
"    and ``img_{i}.png`` (the ground-truth mask, despite the name) into\n",
"    ``folder``, creating it if needed. Does nothing for an empty loader. The\n",
"    model runs in eval mode and is returned to its previous mode.\n",
"    \"\"\"\n",
"    batch = next(iter(loader), None)\n",
"    if batch is None:\n",
"        return\n",
"    x, y = batch\n",
"    os.makedirs(folder, exist_ok=True)\n",
"\n",
"    was_training = model.training\n",
"    model.eval()\n",
"    with torch.no_grad():\n",
"        preds = (torch.sigmoid(model(x.to(device))) > 0.5).float()\n",
"    model.train(was_training)\n",
"\n",
"    targets = y.unsqueeze(1)\n",
"    for i in range(preds.shape[0]):\n",
"        # os.path.join works whether or not ``folder`` ends with a separator.\n",
"        torchvision.utils.save_image(preds[i], os.path.join(folder, f\"pred_{i}.png\"))\n",
"        torchvision.utils.save_image(targets[i], os.path.join(folder, f\"img_{i}.png\"))\n",
"\n",
"# train the model and update parameters of model\n",
"def train_fn(loader, model, optimizer, loss_fn, scaler):\n",
"    \"\"\"Run one epoch of mixed-precision training over ``loader``.\n",
"\n",
"    Each batch is moved to ``DEVICE``; the forward pass and loss run under\n",
"    ``torch.cuda.amp.autocast`` and the update goes through ``scaler``.\n",
"    The progress bar shows the most recent batch loss.\n",
"    \"\"\"\n",
"    progress = tqdm(loader)\n",
"    for inputs, masks in progress:\n",
"        inputs = inputs.to(device=DEVICE)\n",
"        # Add a channel axis at dim 1 so the targets line up with the model's\n",
"        # output (presumably (N, 1, H, W) - TODO confirm).\n",
"        masks = masks.unsqueeze(1).to(device=DEVICE)\n",
"\n",
"        with torch.cuda.amp.autocast():\n",
"            logits = model(inputs)\n",
"            loss = loss_fn(logits, masks)\n",
"\n",
"        optimizer.zero_grad()\n",
"        scaler.scale(loss).backward()\n",
"        scaler.step(optimizer)\n",
"        scaler.update()\n",
"\n",
"        progress.set_postfix(loss=loss.item())\n",
"\n",
"# load and train the deep learning model\n",
"def _build_transforms():\n",
"    \"\"\"Return ``(train_transform, val_transform)`` for the multi-frame samples.\n",
"\n",
"    Besides the primary ``image`` target, ``image0``..``image9`` are registered\n",
"    as extra image targets so every frame and the mask get the same random\n",
"    augmentation. Both pipelines divide pixel values by 255 via A.Normalize.\n",
"    \"\"\"\n",
"    additional_targets = {f\"image{k}\": \"image\" for k in range(10)}\n",
"    additional_targets[\"mask\"] = \"mask\"\n",
"\n",
"    train_transform = A.Compose(\n",
"        [\n",
"            A.Rotate(limit=35, p=1.0),\n",
"            A.HorizontalFlip(p=0.5),\n",
"            A.VerticalFlip(p=0.5),\n",
"            A.GaussianBlur(blur_limit=(3, 5), p=0.3),\n",
"            A.Normalize(mean=0, std=1, max_pixel_value=255.0),\n",
"            A.RandomBrightnessContrast(p=0.3),\n",
"            A.Sharpen(alpha=(0.2, 0.5), lightness=(0.5, 1.0), p=0.3),\n",
"            ToTensorV2(),\n",
"        ],\n",
"        additional_targets=additional_targets,\n",
"    )\n",
"    val_transform = A.Compose(\n",
"        [A.Normalize(mean=0, std=1, max_pixel_value=255.0), ToTensorV2()],\n",
"        additional_targets=additional_targets,\n",
"    )\n",
"    return train_transform, val_transform\n",
"\n",
"\n",
"def _build_model():\n",
"    \"\"\"Return a DynamicUnet with a timm ResNet-34 encoder, moved to DEVICE.\n",
"\n",
"    The encoder's stem convolution is replaced by a freshly initialised one\n",
"    with 10 input channels (one per frame), its last two children are dropped\n",
"    to keep only the feature extractor, and the U-Net outputs one channel of\n",
"    logits for an IMAGE_HEIGHT x IMAGE_WIDTH input.\n",
"    \"\"\"\n",
"    encoder = timm.create_model(\"resnet34\", pretrained=True)\n",
"    encoder.conv1 = nn.Conv2d(\n",
"        10, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
"    body = nn.Sequential(*list(encoder.children())[:-2])\n",
"    return DynamicUnet(body, 1, (IMAGE_HEIGHT, IMAGE_WIDTH), norm_type=None).to(DEVICE)\n",
"\n",
"\n",
"def main():\n",
"    \"\"\"Build transforms, model and loaders, then train for NUM_EPOCHS epochs.\n",
"\n",
"    When LOAD_MODEL is set and the checkpoint file exists, its weights are\n",
"    loaded first. After every epoch the model and optimizer state are saved,\n",
"    validation accuracy and Dice are printed, and sample predictions are\n",
"    written to saved_images/.\n",
"    \"\"\"\n",
"    checkpoint_path = \"models/UNetCellDivision10.pth.tar\"\n",
"    train_transform, val_transform = _build_transforms()\n",
"    model = _build_model()\n",
"\n",
"    # Binary masks with one logit per pixel, hence BCE on raw logits.\n",
"    loss_fn = nn.BCEWithLogitsLoss()\n",
"    optimizer = optim.Adam(\n",
"        model.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.999), eps=1e-08)\n",
"\n",
"    train_loader, val_loader = get_loaders(\n",
"        TRAIN_IMG_DIR,\n",
"        TRAIN_MASK_DIR,\n",
"        VAL_IMG_DIR,\n",
"        VAL_MASK_DIR,\n",
"        BATCH_SIZE,\n",
"        train_transform,\n",
"        val_transform,\n",
"        NUM_WORKERS,\n",
"        PIN_MEMORY,\n",
"    )\n",
"\n",
"    if LOAD_MODEL:\n",
"        if os.path.exists(checkpoint_path):\n",
"            # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.\n",
"            load_checkpoint(torch.load(checkpoint_path, map_location=DEVICE), model)\n",
"        else:\n",
"            print(f\"=> No checkpoint at {checkpoint_path}; starting from the new model\")\n",
"\n",
"    # The scaler is a no-op without CUDA; disabling it explicitly avoids its warning.\n",
"    scaler = torch.cuda.amp.GradScaler(enabled=DEVICE == \"cuda\")\n",
"\n",
"    for _ in range(NUM_EPOCHS):\n",
"        train_fn(train_loader, model, optimizer, loss_fn, scaler)\n",
"\n",
"        save_checkpoint(\n",
"            {\"state_dict\": model.state_dict(), \"optimizer\": optimizer.state_dict()},\n",
"            filename=checkpoint_path,\n",
"        )\n",
"\n",
"        check_accuracy(val_loader, model, device=DEVICE)\n",
"        save_predictions_as_imgs(\n",
"            val_loader, model, folder=\"saved_images/\", device=DEVICE)\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"ename": "NameError",
"evalue": "name 'timm' is not defined",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-3-34a49ee8c262>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# Train model\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mLEARNING_RATE\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m5e-5\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mmain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m<ipython-input-2-4927f734137e>\u001b[0m in \u001b[0;36mmain\u001b[0;34m()\u001b[0m\n\u001b[1;32m 228\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 229\u001b[0m \u001b[0;31m# make the UNetOrientation model\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 230\u001b[0;31m \u001b[0mresnet\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtimm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcreate_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"resnet34\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpretrained\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 231\u001b[0m resnet.conv1 = nn.Conv2d(10, 64, kernel_size=(\n\u001b[1;32m 232\u001b[0m 7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
"\u001b[0;31mNameError\u001b[0m: name 'timm' is not defined"
]
}
],
"source": [
"# Train model \n",
"LEARNING_RATE = 5e-5\n",
"main()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r\n",
" 0%| | 0/112 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"=> Loading checkpoint\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 112/112 [02:23<00:00, 1.28s/it, loss=0.0104] \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"=> Saving checkpoint\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 56/56 [00:20<00:00, 2.78it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy 99.67132568359375%\n",
"Dice score 0.09293054789304733\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 112/112 [02:23<00:00, 1.28s/it, loss=0.00212]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"=> Saving checkpoint\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 56/56 [00:20<00:00, 2.78it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy 99.685791015625%\n",
"Dice score 0.12588801980018616\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 112/112 [02:23<00:00, 1.28s/it, loss=0.00666]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"=> Saving checkpoint\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 56/56 [00:20<00:00, 2.77it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy 99.7662353515625%\n",
"Dice score 0.5328267216682434\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 112/112 [02:23<00:00, 1.28s/it, loss=0.00052]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"=> Saving checkpoint\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 56/56 [00:20<00:00, 2.77it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy 99.78858947753906%\n",
"Dice score 0.5896043181419373\n"
]
}
],
"source": [
"# Train model \n",
"LEARNING_RATE = 5e-5\n",
"main()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r\n",
" 0%| | 0/112 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"=> Loading checkpoint\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 112/112 [02:22<00:00, 1.27s/it, loss=0.00219]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"=> Saving checkpoint\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 56/56 [00:20<00:00, 2.78it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy 99.80276489257812%\n",
"Dice score 0.6420930027961731\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 112/112 [02:22<00:00, 1.27s/it, loss=0.00807]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"=> Saving checkpoint\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 56/56 [00:20<00:00, 2.78it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy 99.80888366699219%\n",
"Dice score 0.6915411353111267\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 112/112 [02:22<00:00, 1.28s/it, loss=0.00152]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"=> Saving checkpoint\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 56/56 [00:20<00:00, 2.78it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy 99.81129455566406%\n",
"Dice score 0.6538491249084473\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n",
"100%|██████████| 112/112 [02:22<00:00, 1.27s/it, loss=0.0059] \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"=> Saving checkpoint\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 56/56 [00:20<00:00, 2.78it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy 99.82037353515625%\n",
"Dice score 0.7002066969871521\n"
]
}
],
"source": [
"LEARNING_RATE = 1e-5\n",
"main()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r\n",
" 0%| | 0/112 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"=> Loading checkpoint\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 112/112 [02:22<00:00, 1.27s/it, loss=0.00187]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"=> Saving checkpoint\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 56/56 [00:20<00:00, 2.78it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy 99.8251724243164%\n",
"Dice score 0.6943625211715698\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 112/112 [02:22<00:00, 1.27s/it, loss=0.00454]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"=> Saving checkpoint\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 56/56 [00:20<00:00, 2.77it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy 99.8243408203125%\n",
"Dice score 0.6834194660186768\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 112/112 [02:22<00:00, 1.27s/it, loss=0.0039] \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"=> Saving checkpoint\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 56/56 [00:20<00:00, 2.77it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy 99.82669067382812%\n",
"Dice score 0.7175348997116089\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 112/112 [02:22<00:00, 1.28s/it, loss=0.00666]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"=> Saving checkpoint\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 56/56 [00:20<00:00, 2.78it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy 99.825439453125%\n",
"Dice score 0.6996853947639465\n"
]
}
],
"source": [
"LEARNING_RATE = 5e-6\n",
"main()"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r\n",
" 0%| | 0/112 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"=> Loading checkpoint\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 112/112 [02:22<00:00, 1.27s/it, loss=0.00398]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"=> Saving checkpoint\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 56/56 [00:20<00:00, 2.78it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy 99.82569122314453%\n",
"Dice score 0.7320536971092224\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 112/112 [02:22<00:00, 1.27s/it, loss=0.000298]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"=> Saving checkpoint\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 56/56 [00:20<00:00, 2.77it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy 99.83374786376953%\n",
"Dice score 0.7188710570335388\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r\n",
" 0%| | 0/112 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"=> Loading checkpoint\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 112/112 [02:22<00:00, 1.27s/it, loss=0.005] \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"=> Saving checkpoint\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 56/56 [00:20<00:00, 2.78it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy 99.83172607421875%\n",
"Dice score 0.7115412354469299\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 112/112 [02:22<00:00, 1.27s/it, loss=0.005] \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"=> Saving checkpoint\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 56/56 [00:20<00:00, 2.78it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy 99.83241271972656%\n",
"Dice score 0.7155706882476807\n"
]
}
],
"source": [
"LEARNING_RATE = 1e-5\n",
"main()\n",
"LEARNING_RATE = 5e-6\n",
"main()"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\r\n",
" 0%| | 0/112 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"=> Loading checkpoint\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 112/112 [02:22<00:00, 1.27s/it, loss=0.00953]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"=> Saving checkpoint\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 56/56 [00:20<00:00, 2.78it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy 99.83970642089844%\n",
"Dice score 0.7425869107246399\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 112/112 [02:22<00:00, 1.27s/it, loss=0.00263]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"=> Saving checkpoint\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 56/56 [00:20<00:00, 2.76it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy 99.84320068359375%\n",
"Dice score 0.7368522882461548\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 112/112 [02:22<00:00, 1.27s/it, loss=0.00136]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"=> Saving checkpoint\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 56/56 [00:20<00:00, 2.75it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy 99.84391021728516%\n",
"Dice score 0.7402313947677612\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 112/112 [02:22<00:00, 1.27s/it, loss=0.000884]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"=> Saving checkpoint\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 56/56 [00:20<00:00, 2.77it/s]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy 99.84194946289062%\n",
"Dice score 0.7370697259902954\n"
]
}
],
"source": [
"LEARNING_RATE = 5e-6\n",
"main()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# Rebuild the network architecture used in main() so it can be inspected with\n",
"# torchsummary. No checkpoint is loaded here: the encoder has pretrained weights\n",
"# and the replacement 10-channel conv1 is freshly initialised.\n",
"resnet = timm.create_model(\"resnet34\", pretrained=True)\n",
"resnet.conv1 = nn.Conv2d(10, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
"# Drop the last two children (timm's global pool + classifier head) to keep a\n",
"# spatial feature encoder for DynamicUnet.\n",
"encoder = nn.Sequential(*list(resnet.children())[:-2])\n",
"model = DynamicUnet(encoder, 1, (512, 512), norm_type=None).to(DEVICE)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"----------------------------------------------------------------\n",
" Layer (type) Output Shape Param #\n",
"================================================================\n",
" Conv2d-1 [-1, 64, 256, 256] 31,360\n",
" BatchNorm2d-2 [-1, 64, 256, 256] 128\n",
" ReLU-3 [-1, 64, 256, 256] 0\n",
" MaxPool2d-4 [-1, 64, 128, 128] 0\n",
" Conv2d-5 [-1, 64, 128, 128] 36,864\n",
" BatchNorm2d-6 [-1, 64, 128, 128] 128\n",
" Identity-7 [-1, 64, 128, 128] 0\n",
" ReLU-8 [-1, 64, 128, 128] 0\n",
" Identity-9 [-1, 64, 128, 128] 0\n",
" Conv2d-10 [-1, 64, 128, 128] 36,864\n",
" BatchNorm2d-11 [-1, 64, 128, 128] 128\n",
" ReLU-12 [-1, 64, 128, 128] 0\n",
" BasicBlock-13 [-1, 64, 128, 128] 0\n",
" Conv2d-14 [-1, 64, 128, 128] 36,864\n",
" BatchNorm2d-15 [-1, 64, 128, 128] 128\n",
" Identity-16 [-1, 64, 128, 128] 0\n",
" ReLU-17 [-1, 64, 128, 128] 0\n",
" Identity-18 [-1, 64, 128, 128] 0\n",
" Conv2d-19 [-1, 64, 128, 128] 36,864\n",
" BatchNorm2d-20 [-1, 64, 128, 128] 128\n",
" ReLU-21 [-1, 64, 128, 128] 0\n",
" BasicBlock-22 [-1, 64, 128, 128] 0\n",
" Conv2d-23 [-1, 64, 128, 128] 36,864\n",
" BatchNorm2d-24 [-1, 64, 128, 128] 128\n",
" Identity-25 [-1, 64, 128, 128] 0\n",
" ReLU-26 [-1, 64, 128, 128] 0\n",
" Identity-27 [-1, 64, 128, 128] 0\n",
" Conv2d-28 [-1, 64, 128, 128] 36,864\n",
" BatchNorm2d-29 [-1, 64, 128, 128] 128\n",
" ReLU-30 [-1, 64, 128, 128] 0\n",
" BasicBlock-31 [-1, 64, 128, 128] 0\n",
" Conv2d-32 [-1, 128, 64, 64] 73,728\n",
" BatchNorm2d-33 [-1, 128, 64, 64] 256\n",
" Identity-34 [-1, 128, 64, 64] 0\n",
" ReLU-35 [-1, 128, 64, 64] 0\n",
" Identity-36 [-1, 128, 64, 64] 0\n",
" Conv2d-37 [-1, 128, 64, 64] 147,456\n",
" BatchNorm2d-38 [-1, 128, 64, 64] 256\n",
" Conv2d-39 [-1, 128, 64, 64] 8,192\n",
" BatchNorm2d-40 [-1, 128, 64, 64] 256\n",
" ReLU-41 [-1, 128, 64, 64] 0\n",
" BasicBlock-42 [-1, 128, 64, 64] 0\n",
" Conv2d-43 [-1, 128, 64, 64] 147,456\n",
" BatchNorm2d-44 [-1, 128, 64, 64] 256\n",
" Identity-45 [-1, 128, 64, 64] 0\n",
" ReLU-46 [-1, 128, 64, 64] 0\n",
" Identity-47 [-1, 128, 64, 64] 0\n",
" Conv2d-48 [-1, 128, 64, 64] 147,456\n",
" BatchNorm2d-49 [-1, 128, 64, 64] 256\n",
" ReLU-50 [-1, 128, 64, 64] 0\n",
" BasicBlock-51 [-1, 128, 64, 64] 0\n",
" Conv2d-52 [-1, 128, 64, 64] 147,456\n",
" BatchNorm2d-53 [-1, 128, 64, 64] 256\n",
" Identity-54 [-1, 128, 64, 64] 0\n",
" ReLU-55 [-1, 128, 64, 64] 0\n",
" Identity-56 [-1, 128, 64, 64] 0\n",
" Conv2d-57 [-1, 128, 64, 64] 147,456\n",
" BatchNorm2d-58 [-1, 128, 64, 64] 256\n",
" ReLU-59 [-1, 128, 64, 64] 0\n",
" BasicBlock-60 [-1, 128, 64, 64] 0\n",
" Conv2d-61 [-1, 128, 64, 64] 147,456\n",
" BatchNorm2d-62 [-1, 128, 64, 64] 256\n",
" Identity-63 [-1, 128, 64, 64] 0\n",
" ReLU-64 [-1, 128, 64, 64] 0\n",
" Identity-65 [-1, 128, 64, 64] 0\n",
" Conv2d-66 [-1, 128, 64, 64] 147,456\n",
" BatchNorm2d-67 [-1, 128, 64, 64] 256\n",
" ReLU-68 [-1, 128, 64, 64] 0\n",
" BasicBlock-69 [-1, 128, 64, 64] 0\n",
" Conv2d-70 [-1, 256, 32, 32] 294,912\n",
" BatchNorm2d-71 [-1, 256, 32, 32] 512\n",
" Identity-72 [-1, 256, 32, 32] 0\n",
" ReLU-73 [-1, 256, 32, 32] 0\n",
" Identity-74 [-1, 256, 32, 32] 0\n",
" Conv2d-75 [-1, 256, 32, 32] 589,824\n",
" BatchNorm2d-76 [-1, 256, 32, 32] 512\n",
" Conv2d-77 [-1, 256, 32, 32] 32,768\n",
" BatchNorm2d-78 [-1, 256, 32, 32] 512\n",
" ReLU-79 [-1, 256, 32, 32] 0\n",
" BasicBlock-80 [-1, 256, 32, 32] 0\n",
" Conv2d-81 [-1, 256, 32, 32] 589,824\n",
" BatchNorm2d-82 [-1, 256, 32, 32] 512\n",
" Identity-83 [-1, 256, 32, 32] 0\n",
" ReLU-84 [-1, 256, 32, 32] 0\n",
" Identity-85 [-1, 256, 32, 32] 0\n",
" Conv2d-86 [-1, 256, 32, 32] 589,824\n",
" BatchNorm2d-87 [-1, 256, 32, 32] 512\n",
" ReLU-88 [-1, 256, 32, 32] 0\n",
" BasicBlock-89 [-1, 256, 32, 32] 0\n",
" Conv2d-90 [-1, 256, 32, 32] 589,824\n",
" BatchNorm2d-91 [-1, 256, 32, 32] 512\n",
" Identity-92 [-1, 256, 32, 32] 0\n",
" ReLU-93 [-1, 256, 32, 32] 0\n",
" Identity-94 [-1, 256, 32, 32] 0\n",
" Conv2d-95 [-1, 256, 32, 32] 589,824\n",
" BatchNorm2d-96 [-1, 256, 32, 32] 512\n",
" ReLU-97 [-1, 256, 32, 32] 0\n",
" BasicBlock-98 [-1, 256, 32, 32] 0\n",
" Conv2d-99 [-1, 256, 32, 32] 589,824\n",
" BatchNorm2d-100 [-1, 256, 32, 32] 512\n",
" Identity-101 [-1, 256, 32, 32] 0\n",
" ReLU-102 [-1, 256, 32, 32] 0\n",
" Identity-103 [-1, 256, 32, 32] 0\n",
" Conv2d-104 [-1, 256, 32, 32] 589,824\n",
" BatchNorm2d-105 [-1, 256, 32, 32] 512\n",
" ReLU-106 [-1, 256, 32, 32] 0\n",
" BasicBlock-107 [-1, 256, 32, 32] 0\n",
" Conv2d-108 [-1, 256, 32, 32] 589,824\n",
" BatchNorm2d-109 [-1, 256, 32, 32] 512\n",
" Identity-110 [-1, 256, 32, 32] 0\n",
" ReLU-111 [-1, 256, 32, 32] 0\n",
" Identity-112 [-1, 256, 32, 32] 0\n",
" Conv2d-113 [-1, 256, 32, 32] 589,824\n",
" BatchNorm2d-114 [-1, 256, 32, 32] 512\n",
" ReLU-115 [-1, 256, 32, 32] 0\n",
" BasicBlock-116 [-1, 256, 32, 32] 0\n",
" Conv2d-117 [-1, 256, 32, 32] 589,824\n",
" BatchNorm2d-118 [-1, 256, 32, 32] 512\n",
" Identity-119 [-1, 256, 32, 32] 0\n",
" ReLU-120 [-1, 256, 32, 32] 0\n",
" Identity-121 [-1, 256, 32, 32] 0\n",
" Conv2d-122 [-1, 256, 32, 32] 589,824\n",
" BatchNorm2d-123 [-1, 256, 32, 32] 512\n",
" ReLU-124 [-1, 256, 32, 32] 0\n",
" BasicBlock-125 [-1, 256, 32, 32] 0\n",
" Conv2d-126 [-1, 512, 16, 16] 1,179,648\n",
" BatchNorm2d-127 [-1, 512, 16, 16] 1,024\n",
" Identity-128 [-1, 512, 16, 16] 0\n",
" ReLU-129 [-1, 512, 16, 16] 0\n",
" Identity-130 [-1, 512, 16, 16] 0\n",
" Conv2d-131 [-1, 512, 16, 16] 2,359,296\n",
" BatchNorm2d-132 [-1, 512, 16, 16] 1,024\n",
" Conv2d-133 [-1, 512, 16, 16] 131,072\n",
" BatchNorm2d-134 [-1, 512, 16, 16] 1,024\n",
" ReLU-135 [-1, 512, 16, 16] 0\n",
" BasicBlock-136 [-1, 512, 16, 16] 0\n",
" Conv2d-137 [-1, 512, 16, 16] 2,359,296\n",
" BatchNorm2d-138 [-1, 512, 16, 16] 1,024\n",
" Identity-139 [-1, 512, 16, 16] 0\n",
" ReLU-140 [-1, 512, 16, 16] 0\n",
" Identity-141 [-1, 512, 16, 16] 0\n",
" Conv2d-142 [-1, 512, 16, 16] 2,359,296\n",
" BatchNorm2d-143 [-1, 512, 16, 16] 1,024\n",
" ReLU-144 [-1, 512, 16, 16] 0\n",
" BasicBlock-145 [-1, 512, 16, 16] 0\n",
" Conv2d-146 [-1, 512, 16, 16] 2,359,296\n",
" BatchNorm2d-147 [-1, 512, 16, 16] 1,024\n",
" Identity-148 [-1, 512, 16, 16] 0\n",
" ReLU-149 [-1, 512, 16, 16] 0\n",
" Identity-150 [-1, 512, 16, 16] 0\n",
" Conv2d-151 [-1, 512, 16, 16] 2,359,296\n",
" BatchNorm2d-152 [-1, 512, 16, 16] 1,024\n",
" ReLU-153 [-1, 512, 16, 16] 0\n",
" BasicBlock-154 [-1, 512, 16, 16] 0\n",
" BatchNorm2d-155 [-1, 512, 16, 16] 1,024\n",
" ReLU-156 [-1, 512, 16, 16] 0\n",
" Conv2d-157 [-1, 1024, 16, 16] 4,719,616\n",
" ReLU-158 [-1, 1024, 16, 16] 0\n",
" Conv2d-159 [-1, 512, 16, 16] 4,719,104\n",
" ReLU-160 [-1, 512, 16, 16] 0\n",
" Conv2d-161 [-1, 1024, 16, 16] 525,312\n",
" ReLU-162 [-1, 1024, 16, 16] 0\n",
" PixelShuffle-163 [-1, 256, 32, 32] 0\n",
" BatchNorm2d-164 [-1, 256, 32, 32] 512\n",
" ReLU-165 [-1, 512, 32, 32] 0\n",
" Conv2d-166 [-1, 512, 32, 32] 2,359,808\n",
" ReLU-167 [-1, 512, 32, 32] 0\n",
" Conv2d-168 [-1, 512, 32, 32] 2,359,808\n",
" ReLU-169 [-1, 512, 32, 32] 0\n",
" UnetBlock-170 [-1, 512, 32, 32] 0\n",
" Conv2d-171 [-1, 1024, 32, 32] 525,312\n",
" ReLU-172 [-1, 1024, 32, 32] 0\n",
" PixelShuffle-173 [-1, 256, 64, 64] 0\n",
" BatchNorm2d-174 [-1, 128, 64, 64] 256\n",
" ReLU-175 [-1, 384, 64, 64] 0\n",
" Conv2d-176 [-1, 384, 64, 64] 1,327,488\n",
" ReLU-177 [-1, 384, 64, 64] 0\n",
" Conv2d-178 [-1, 384, 64, 64] 1,327,488\n",
" ReLU-179 [-1, 384, 64, 64] 0\n",
" UnetBlock-180 [-1, 384, 64, 64] 0\n",
" Conv2d-181 [-1, 768, 64, 64] 295,680\n",
" ReLU-182 [-1, 768, 64, 64] 0\n",
" PixelShuffle-183 [-1, 192, 128, 128] 0\n",
" BatchNorm2d-184 [-1, 64, 128, 128] 128\n",
" ReLU-185 [-1, 256, 128, 128] 0\n",
" Conv2d-186 [-1, 256, 128, 128] 590,080\n",
" ReLU-187 [-1, 256, 128, 128] 0\n",
" Conv2d-188 [-1, 256, 128, 128] 590,080\n",
" ReLU-189 [-1, 256, 128, 128] 0\n",
" UnetBlock-190 [-1, 256, 128, 128] 0\n",
" Conv2d-191 [-1, 512, 128, 128] 131,584\n",
" ReLU-192 [-1, 512, 128, 128] 0\n",
" PixelShuffle-193 [-1, 128, 256, 256] 0\n",
" BatchNorm2d-194 [-1, 64, 256, 256] 128\n",
" ReLU-195 [-1, 192, 256, 256] 0\n",
" Conv2d-196 [-1, 96, 256, 256] 165,984\n",
" ReLU-197 [-1, 96, 256, 256] 0\n",
" Conv2d-198 [-1, 96, 256, 256] 83,040\n",
" ReLU-199 [-1, 96, 256, 256] 0\n",
" UnetBlock-200 [-1, 96, 256, 256] 0\n",
" Conv2d-201 [-1, 384, 256, 256] 37,248\n",
" ReLU-202 [-1, 384, 256, 256] 0\n",
" PixelShuffle-203 [-1, 96, 512, 512] 0\n",
" ResizeToOrig-204 [-1, 96, 512, 512] 0\n",
" MergeLayer-205 [-1, 106, 512, 512] 0\n",
" Conv2d-206 [-1, 106, 512, 512] 101,230\n",
" ReLU-207 [-1, 106, 512, 512] 0\n",
" Conv2d-208 [-1, 106, 512, 512] 101,230\n",
" ReLU-209 [-1, 106, 512, 512] 0\n",
" ResBlock-210 [-1, 106, 512, 512] 0\n",
" Conv2d-211 [-1, 1, 512, 512] 107\n",
" ToTensorBase-212 [-1, 1, 512, 512] 0\n",
"================================================================\n",
"Total params: 41,268,871\n",
"Trainable params: 41,268,871\n",
"Non-trainable params: 0\n",
"----------------------------------------------------------------\n",
"Input size (MB): 10.00\n",
"Forward/backward pass size (MB): 3629.00\n",
"Params size (MB): 157.43\n",
"Estimated Total Size (MB): 3796.43\n",
"----------------------------------------------------------------\n"
]
}
],
"source": [
"from torchsummary import summary\n",
"summary(model, (10, 512, 512))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.6"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
