https://github.com/turleyjm/cell-division-dl-plugin
24 August 2024, 11:55:15 UTC
To reference or cite the objects present in the Software Heritage archive, permalinks based on SoftWare Hash IDentifiers (SWHIDs) must be used.
  • content: swh:1:cnt:87806507acae27db289e98400b198dfaf1698cc7
  • directory: swh:1:dir:c0e8faef06e27068019965a2e73727b0a40d168f
  • revision: swh:1:rev:f6c3290670f23524d47ccddc873ef5673eda4b3c
  • snapshot: swh:1:snp:939fe1ceeff9f734ad605b447dd8ed379af63455

Tip revision: f6c3290670f23524d47ccddc873ef5673eda4b3c authored by turleyjm on 11 July 2024, 02:09:47 UTC
Update README.md
trainingUNetCellDivision10.ipynb
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "ename": "AttributeError",
     "evalue": "module 'wandb.proto.wandb_internal_pb2' has no attribute 'Result'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-1-2557cbe562cb>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     16\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mmatplotlib\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mpyplot\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mplt\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     17\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtifffile\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 18\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mtimm\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     19\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mfastai\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvision\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mall\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/envs/fastai/lib/python3.8/site-packages/timm/__init__.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mversion\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0m__version__\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mlayers\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mis_scriptable\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mis_exportable\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mset_scriptable\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mset_exportable\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mmodels\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mcreate_model\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlist_models\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlist_pretrained\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mis_model\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlist_modules\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel_entrypoint\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m\\\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      4\u001b[0m     \u001b[0mis_model_pretrained\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mget_pretrained_cfg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mget_pretrained_cfg_value\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/envs/fastai/lib/python3.8/site-packages/timm/models/__init__.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mbeit\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      2\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mbyoanet\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mbyobnet\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mcait\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mcoat\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/envs/fastai/lib/python3.8/site-packages/timm/models/beit.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     47\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcheckpoint\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mcheckpoint\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     48\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 49\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mtimm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mIMAGENET_DEFAULT_MEAN\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mIMAGENET_DEFAULT_STD\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     50\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtimm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlayers\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mPatchEmbed\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mMlp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mSwiGLU\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mLayerNorm\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mDropPath\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrunc_normal_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0muse_fused_attn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     51\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtimm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlayers\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mresample_patch_embed\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresample_abs_pos_embed\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresize_rel_pos_bias_table\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/envs/fastai/lib/python3.8/site-packages/timm/data/__init__.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mconfig\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mresolve_data_config\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresolve_model_data_config\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mconstants\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mImageDataset\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mIterableImageDataset\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mAugMixDataset\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      6\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mdataset_factory\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mcreate_dataset\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      7\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mdataset_info\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mDatasetInfo\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mCustomDatasetInfo\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/envs/fastai/lib/python3.8/site-packages/timm/data/dataset.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     11\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mPIL\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mImage\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     12\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 13\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mreaders\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mcreate_reader\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     14\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     15\u001b[0m \u001b[0m_logger\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlogging\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgetLogger\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m__name__\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/envs/fastai/lib/python3.8/site-packages/timm/data/readers/__init__.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mreader_factory\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mcreate_reader\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      2\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mimg_extensions\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/envs/fastai/lib/python3.8/site-packages/timm/data/readers/reader_factory.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mreader_image_folder\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mReaderImageFolder\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mreader_image_in_tar\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mReaderImageInTar\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/envs/fastai/lib/python3.8/site-packages/timm/data/readers/reader_image_folder.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      9\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtyping\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mDict\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mList\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mOptional\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mSet\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mTuple\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mUnion\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     10\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 11\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mtimm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmisc\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mnatural_key\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     12\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     13\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mclass_map\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mload_class_map\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/envs/fastai/lib/python3.8/site-packages/timm/utils/__init__.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     13\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mmodel_ema\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mModelEma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mModelEmaV2\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     14\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mrandom\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mrandom_seed\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 15\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0msummary\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mupdate_summary\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mget_outdir\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m/opt/conda/envs/fastai/lib/python3.8/site-packages/timm/utils/summary.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      7\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mcollections\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mOrderedDict\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      8\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 9\u001b[0;31m     \u001b[0;32mimport\u001b[0m \u001b[0mwandb\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     10\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mImportError\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     11\u001b[0m     \u001b[0;32mpass\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/envs/fastai/lib/python3.8/site-packages/wandb/__init__.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     25\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mwandb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0merrors\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mterm\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtermsetup\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtermlog\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtermerror\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtermwarn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     26\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 27\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mwandb\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0msdk\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mwandb_sdk\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     28\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     29\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mwandb\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/envs/fastai/lib/python3.8/site-packages/wandb/sdk/__init__.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     23\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     24\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mwandb_helper\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mhelper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 25\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0martifacts\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0martifact\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mArtifact\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     26\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mwandb_alerts\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mAlertLevel\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     27\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mwandb_config\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mConfig\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/envs/fastai/lib/python3.8/site-packages/wandb/sdk/artifacts/artifact.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     35\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mwandb\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     36\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mwandb\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mdata_types\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0menv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mutil\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 37\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mwandb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapis\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnormalize\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mnormalize_exceptions\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     38\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mwandb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapis\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpublic\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mArtifactCollection\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mArtifactFiles\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mRetryingClient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mRun\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     39\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mwandb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata_types\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mWBValue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/envs/fastai/lib/python3.8/site-packages/wandb/apis/__init__.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     41\u001b[0m \u001b[0mreset_path\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mutil\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvendor_setup\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     42\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 43\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0minternal\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mApi\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mInternalApi\u001b[0m  \u001b[0;31m# noqa\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     44\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mpublic\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mApi\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mPublicApi\u001b[0m  \u001b[0;31m# noqa\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     45\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/envs/fastai/lib/python3.8/site-packages/wandb/apis/internal.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtyping\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mAny\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mwandb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msdk\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minternal\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minternal_api\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mApi\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mInternalApi\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/envs/fastai/lib/python3.8/site-packages/wandb/sdk/internal/internal_api.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     46\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mwandb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msdk\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhashutil\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mB64MD5\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmd5_file_b64\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     47\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 48\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlib\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mretry\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     49\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfilenames\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mDIFF_FNAME\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mMETADATA_FNAME\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     50\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgitlib\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mGitRepo\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/envs/fastai/lib/python3.8/site-packages/wandb/sdk/lib/retry.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     15\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mwandb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutil\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mCheckRetryFnType\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     16\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 17\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mmailbox\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mContextCancelledError\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     18\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     19\u001b[0m \u001b[0mlogger\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlogging\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgetLogger\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m__name__\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/envs/fastai/lib/python3.8/site-packages/wandb/sdk/lib/mailbox.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m    100\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    101\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 102\u001b[0;31m \u001b[0;32mclass\u001b[0m \u001b[0m_MailboxSlot\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    103\u001b[0m     \u001b[0m_result\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptional\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mpb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mResult\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    104\u001b[0m     \u001b[0m_event\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mthreading\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mEvent\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/envs/fastai/lib/python3.8/site-packages/wandb/sdk/lib/mailbox.py\u001b[0m in \u001b[0;36m_MailboxSlot\u001b[0;34m()\u001b[0m\n\u001b[1;32m    101\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    102\u001b[0m \u001b[0;32mclass\u001b[0m \u001b[0m_MailboxSlot\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 103\u001b[0;31m     \u001b[0m_result\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptional\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mpb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mResult\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    104\u001b[0m     \u001b[0m_event\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mthreading\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mEvent\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    105\u001b[0m     \u001b[0m_lock\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mthreading\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mLock\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mAttributeError\u001b[0m: module 'wandb.proto.wandb_internal_pb2' has no attribute 'Result'"
     ]
    }
   ],
   "source": [
    "from cgitb import reset\n",
    "import torch\n",
    "import albumentations as A\n",
    "from albumentations.pytorch import ToTensorV2\n",
    "from tqdm import tqdm\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torchvision\n",
    "from torch.utils.data import DataLoader\n",
    "import os\n",
    "from PIL import Image\n",
    "from torch.utils.data import Dataset\n",
    "import numpy as np\n",
    "import skimage as sm\n",
    "import skimage.io\n",
    "from matplotlib import pyplot as plt\n",
    "import tifffile\n",
    "import timm\n",
    "from fastai.vision.all import *\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters\n",
    "\n",
    "LEARNING_RATE = 1e-4\n",
    "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "BATCH_SIZE = 1\n",
    "NUM_EPOCHS = 4\n",
    "NUM_WORKERS = 2\n",
    "IMAGE_HEIGHT = 512\n",
    "IMAGE_WIDTH = 512\n",
    "PIN_MEMORY = True\n",
    "LOAD_MODEL = True\n",
    "TRAIN_IMG_DIR = \"dat_division/divisionData/train_images/\"\n",
    "TRAIN_MASK_DIR = \"dat_division/divisionData/train_masks/\"\n",
    "VAL_IMG_DIR = \"dat_division/divisionData/val_images/\"\n",
    "VAL_MASK_DIR = \"dat_division/divisionData/val_masks/\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Defines the dataset (wrapped by the dataloaders built in get_loaders)\n",
    "\n",
    "class DivisionDataset(Dataset):\n",
    "    def __init__(self, image_dir, mask_dir, transform=None):\n",
    "        self.image_dir = image_dir\n",
    "        self.mask_dir = mask_dir\n",
    "        self.transform = transform\n",
    "        self.images = os.listdir(image_dir)\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.images)\n",
    "\n",
    "    # gets both the 10 frame images and corresponding mask\n",
    "    def __getitem__(self, index):\n",
    "        img_path = os.path.join(self.image_dir, self.images[index])\n",
    "        mask_path = os.path.join(\n",
    "            self.mask_dir, self.images[index].replace(\".tif\", \"_mask.tif\"))\n",
    "        image = sm.io.imread(img_path).astype(np.float32)\n",
    "        mask = np.array(Image.open(mask_path), dtype=np.float32)\n",
    "        mask0 = mask\n",
    "        mask[mask == 255] = 1\n",
    "        images = torch.tensor(image/256).float()\n",
    "\n",
    "        if self.transform is not None:\n",
    "            # Normilises and transforms the images and masks \n",
    "            transformed = self.transform(image=image[0], image0=image[1], image1=image[2], image2=image[3], \n",
    "                                         image3=image[4], image4=image[5], image5=image[6], image6=image[7], \n",
    "                                         image7=image[8], image8=image[9], mask=mask)\n",
    "            images[0] = transformed[\"image\"]\n",
    "            images[1] = transformed[\"image0\"]\n",
    "            images[2] = transformed[\"image1\"]\n",
    "            images[3] = transformed[\"image2\"]\n",
    "            images[4] = transformed[\"image3\"]\n",
    "            images[5] = transformed[\"image4\"]\n",
    "            images[6] = transformed[\"image5\"]\n",
    "            images[7] = transformed[\"image6\"]\n",
    "            images[8] = transformed[\"image7\"]\n",
    "            images[9] = transformed[\"image8\"]\n",
    "\n",
    "            mask = transformed[\"mask\"]\n",
    "\n",
    "            # saves the mask and image before and after transform to \n",
    "            # check transforms are correctly functioning\n",
    "\n",
    "            # save_transform(image, mask0, transformed)\n",
    "\n",
    "        return images, mask\n",
    "\n",
    "# saves the before and after transform by the augmentations\n",
    "def save_transform(image, mask0, transformed):\n",
    "\n",
    "    result = np.zeros([10, 1034, 1034])\n",
    "    result[:, 0:512, 0:512] = image\n",
    "    result[0, 0:512, 522:] = np.array(transformed[\"image\"])*255\n",
    "    result[1, 0:512, 522:] = np.array(transformed[\"image0\"])*255\n",
    "    result[2, 0:512, 522:] = np.array(transformed[\"image1\"])*255\n",
    "    result[3, 0:512, 522:] = np.array(transformed[\"image2\"])*255\n",
    "    result[4, 0:512, 522:] = np.array(transformed[\"image3\"])*255\n",
    "    result[5, 0:512, 522:] = np.array(transformed[\"image4\"])*255\n",
    "    result[6, 0:512, 522:] = np.array(transformed[\"image5\"])*255\n",
    "    result[7, 0:512, 522:] = np.array(transformed[\"image6\"])*255\n",
    "    result[8, 0:512, 522:] = np.array(transformed[\"image7\"])*255\n",
    "    result[9, 0:512, 522:] = np.array(transformed[\"image8\"])*255\n",
    "\n",
    "    result[:, 522:, 0:512] = mask0*255\n",
    "    result[:, 522:, 522:] = np.array(transformed[\"mask\"])*255\n",
    "\n",
    "    result = np.asarray(result, \"uint8\")\n",
    "    tifffile.imwrite(f\"transformResults/transform.tif\", result)\n",
    "\n",
    "\n",
    "# util\n",
    "\n",
    "# save model parameters\n",
    "def save_checkpoint(state, filename=\"models/UNetCellDivision10.pth.tar\"):\n",
    "    print(\"=> Saving checkpoint\")\n",
    "    torch.save(state, filename)\n",
    "\n",
    "# load model parameters\n",
    "def load_checkpoint(checkpoint, model):\n",
    "    print(\"=> Loading checkpoint\")\n",
    "    model.load_state_dict(checkpoint[\"state_dict\"])\n",
    "\n",
    "# Make the dataloader \n",
    "def get_loaders(\n",
    "    train_dir,\n",
    "    train_maskdir,\n",
    "    val_dir,\n",
    "    val_maskdir,\n",
    "    batch_size,\n",
    "    train_transform,\n",
    "    val_transform,\n",
    "    num_workers=4,\n",
    "    pin_memory=True\n",
    "):\n",
    "    train_ds = DivisionDataset(\n",
    "        image_dir=train_dir,\n",
    "        mask_dir=train_maskdir,\n",
    "        transform=train_transform,\n",
    "    )\n",
    "\n",
    "    train_loader = DataLoader(\n",
    "        train_ds,\n",
    "        batch_size=batch_size,\n",
    "        num_workers=num_workers,\n",
    "        pin_memory=pin_memory,\n",
    "        shuffle=True,\n",
    "    )\n",
    "\n",
    "    val_ds = DivisionDataset(\n",
    "        image_dir=val_dir,\n",
    "        mask_dir=val_maskdir,\n",
    "        transform=val_transform,\n",
    "    )\n",
    "\n",
    "    val_loader = DataLoader(\n",
    "        val_ds,\n",
    "        batch_size=batch_size,\n",
    "        num_workers=num_workers,\n",
    "        pin_memory=pin_memory,\n",
    "        shuffle=False\n",
    "    )\n",
    "\n",
    "    return train_loader, val_loader\n",
    "\n",
    "# define metric to assess model performance \n",
    "def check_accuracy(loader, model, device=\"cuda\"):\n",
    "    num_correct = 0\n",
    "    num_pixels = 0\n",
    "    dice_score = 0\n",
    "    model.eval()\n",
    "    loop = tqdm(loader)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for batch_idx, (x, y) in enumerate(loop):\n",
    "            x = x.to(device)\n",
    "            y = y.to(device).unsqueeze(1)\n",
    "            preds = torch.sigmoid(model(x))\n",
    "            preds = (preds > 0.5).float()\n",
    "            num_correct += (preds == y).sum()\n",
    "            num_pixels += torch.numel(preds)\n",
    "            dice_score += (2 * (preds * y).sum()) / (\n",
    "                (preds + y).sum() + 1e-8\n",
    "            )\n",
    "\n",
    "    print(\n",
    "        f\"Accuracy {num_correct/num_pixels*100}%\"\n",
    "    )\n",
    "    print(f\"Dice score {dice_score/len(loader)}\")\n",
    "    model.train()\n",
    "\n",
    "# saves the ground truth with the model prediciton to folder saved_images\n",
    "def save_predictions_as_imgs(loader, model, folder=\"saved_images/\", device=\"cuda\"):\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        for idx, (x, y) in enumerate(loader):\n",
    "            x = x.to(device)\n",
    "            preds = torch.sigmoid(model(x))\n",
    "            preds = (preds > 0.5).float()\n",
    "            for i in range(preds.shape[0]):\n",
    "                torchvision.utils.save_image(\n",
    "                    preds[i], f\"{folder}pred_{i}.png\"\n",
    "                )\n",
    "                torchvision.utils.save_image(\n",
    "                    y.unsqueeze(1)[i], f\"{folder}img_{i}.png\")\n",
    "\n",
    "            break\n",
    "\n",
    "    model.train()\n",
    "\n",
    "# train the model and update parameters of model\n",
    "def train_fn(loader, model, optimizer, loss_fn, scaler):\n",
    "    loop = tqdm(loader)\n",
    "\n",
    "    for batch_idx, (data, targets) in enumerate(loop):\n",
    "        data = data.to(device=DEVICE)\n",
    "        targets = torch.unsqueeze(targets, 1).to(device=DEVICE)\n",
    "\n",
    "        # forward\n",
    "        with torch.cuda.amp.autocast():\n",
    "            predictions = model(data)\n",
    "            loss = loss_fn(predictions, targets)\n",
    "\n",
    "        # backward\n",
    "        optimizer.zero_grad()\n",
    "        scaler.scale(loss).backward()\n",
    "        scaler.step(optimizer)\n",
    "        scaler.update()\n",
    "\n",
    "        # update tqdm loop\n",
    "        loop.set_postfix(loss=loss.item())\n",
    "\n",
    "# load and train the deep learning model\n",
    "def main():\n",
    "    target10 = {'image0': 'image', 'image1': 'image', 'image2': 'image', 'image3': 'image', \n",
    "                'image4': 'image', 'image5': 'image', 'image6': 'image', 'image7': 'image', \n",
    "                'image8': 'image', 'image9': 'image', 'mask': 'mask'}\n",
    "    # augmentations for training model\n",
    "    train_transform = A.Compose(\n",
    "        [\n",
    "            A.Rotate(limit=35, p=1.0),\n",
    "            A.HorizontalFlip(p=0.5),\n",
    "            A.VerticalFlip(p=0.5),\n",
    "            A.GaussianBlur(blur_limit=(3, 5), p=0.3),\n",
    "            A.Normalize(\n",
    "                mean=0,\n",
    "                std=1,\n",
    "                max_pixel_value=255.0,\n",
    "            ),\n",
    "            A.RandomBrightnessContrast(p=0.3),\n",
    "            A.Sharpen(alpha=(0.2, 0.5), lightness=(0.5, 1.0), p=0.3),\n",
    "            ToTensorV2(),\n",
    "        ],\n",
    "        additional_targets=target10,\n",
    "    )\n",
    "    # augmentations for validation data\n",
    "    val_transform = A.Compose(\n",
    "        [\n",
    "            A.Normalize(\n",
    "                mean=0,\n",
    "                std=1,\n",
    "                max_pixel_value=255.0,\n",
    "            ),\n",
    "            ToTensorV2(),\n",
    "        ],\n",
    "        additional_targets=target10,\n",
    "    )\n",
    "\n",
    "    # make the UNetOrientation model \n",
    "    resnet = timm.create_model(\"resnet34\", pretrained=True)\n",
    "    resnet.conv1 = nn.Conv2d(10, 64, kernel_size=(\n",
    "        7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
    "\n",
    "    m = resnet\n",
    "    m = nn.Sequential(*list(m.children())[:-2])\n",
    "    model = DynamicUnet(m, 1, (512, 512), norm_type=None).to(DEVICE)\n",
    "\n",
    "    loss_fn = nn.BCEWithLogitsLoss()  # if out_channels > 1 => cross entropy loss\n",
    "    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, betas=(\n",
    "        0.9, 0.999), eps=1e-08)# adam learner\n",
    "\n",
    "    # Make the dataloader \n",
    "    train_loader, val_loader = get_loaders(\n",
    "        TRAIN_IMG_DIR,\n",
    "        TRAIN_MASK_DIR,\n",
    "        VAL_IMG_DIR,\n",
    "        VAL_MASK_DIR,\n",
    "        BATCH_SIZE,\n",
    "        train_transform,  # train_transform\n",
    "        val_transform,  # val_transform\n",
    "        NUM_WORKERS,\n",
    "        PIN_MEMORY,\n",
    "    )\n",
    "\n",
    "    # Load training model if one avalable \n",
    "    if LOAD_MODEL:\n",
    "        load_checkpoint(torch.load(\"models/UNetCellDivision10.pth.tar\"), model)\n",
    "#         save_predictions_as_imgs(\n",
    "#             val_loader, model, folder=\"saved_images/\", device=DEVICE)\n",
    "#         check_accuracy(val_loader, model, device=DEVICE)\n",
    "\n",
    "    scaler = torch.cuda.amp.GradScaler()\n",
    "\n",
    "    for epoch in range(NUM_EPOCHS):\n",
    "        # train model\n",
    "        train_fn(train_loader, model, optimizer, loss_fn, scaler)\n",
    "\n",
    "        # save model\n",
    "        checkpoint = {\n",
    "            \"state_dict\": model.state_dict(),\n",
    "            \"optimizer\": optimizer.state_dict(),\n",
    "        }\n",
    "        save_checkpoint(checkpoint)\n",
    "\n",
    "        # check accuracy\n",
    "        check_accuracy(val_loader, model, device=DEVICE)\n",
    "        save_predictions_as_imgs(\n",
    "            val_loader, model, folder=\"saved_images/\", device=DEVICE)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'timm' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-3-34a49ee8c262>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0;31m# Train model\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      2\u001b[0m \u001b[0mLEARNING_RATE\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m5e-5\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mmain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m<ipython-input-2-4927f734137e>\u001b[0m in \u001b[0;36mmain\u001b[0;34m()\u001b[0m\n\u001b[1;32m    228\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    229\u001b[0m     \u001b[0;31m# make the UNetOrientation model\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 230\u001b[0;31m     \u001b[0mresnet\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtimm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcreate_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"resnet34\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpretrained\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    231\u001b[0m     resnet.conv1 = nn.Conv2d(10, 64, kernel_size=(\n\u001b[1;32m    232\u001b[0m         7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
      "\u001b[0;31mNameError\u001b[0m: name 'timm' is not defined"
     ]
    }
   ],
   "source": [
    "# Train model \n",
    "LEARNING_RATE = 5e-5\n",
    "main()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r\n",
      "  0%|          | 0/112 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=> Loading checkpoint\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 112/112 [02:23<00:00,  1.28s/it, loss=0.0104] \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=> Saving checkpoint\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 56/56 [00:20<00:00,  2.78it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy 99.67132568359375%\n",
      "Dice score 0.09293054789304733\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 112/112 [02:23<00:00,  1.28s/it, loss=0.00212]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=> Saving checkpoint\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 56/56 [00:20<00:00,  2.78it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy 99.685791015625%\n",
      "Dice score 0.12588801980018616\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 112/112 [02:23<00:00,  1.28s/it, loss=0.00666]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=> Saving checkpoint\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 56/56 [00:20<00:00,  2.77it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy 99.7662353515625%\n",
      "Dice score 0.5328267216682434\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 112/112 [02:23<00:00,  1.28s/it, loss=0.00052]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=> Saving checkpoint\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 56/56 [00:20<00:00,  2.77it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy 99.78858947753906%\n",
      "Dice score 0.5896043181419373\n"
     ]
    }
   ],
   "source": [
    "# Train model \n",
    "LEARNING_RATE = 5e-5\n",
    "main()"
   ]
  },
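  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The validation logs above report pixel accuracy above 99.6% while the Dice score starts near 0.09. The next cell is a minimal sketch of how the two metrics can diverge, assuming dividing-cell pixels make up only a small fraction of each mask. The toy mask and the all-background prediction are hypothetical (not taken from the dataset); the two formulas mirror those in `check_accuracy`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Hypothetical toy mask: 30 foreground pixels out of 10,000 (0.3%)\n",
    "y = torch.zeros(1, 1, 100, 100)\n",
    "y[..., :3, :10] = 1\n",
    "\n",
    "# Hypothetical prediction: all background\n",
    "preds = torch.zeros_like(y)\n",
    "\n",
    "# Same formulas as check_accuracy, applied to a single batch\n",
    "accuracy = (preds == y).sum() / torch.numel(preds)\n",
    "dice = (2 * (preds * y).sum()) / ((preds + y).sum() + 1e-8)\n",
    "\n",
    "# Accuracy stays near 99.7% although no foreground pixel is found; Dice is 0\n",
    "print(f\"Accuracy {accuracy * 100}%\")\n",
    "print(f\"Dice score {dice}\")"
   ]
  },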
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r\n",
      "  0%|          | 0/112 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=> Loading checkpoint\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 112/112 [02:22<00:00,  1.27s/it, loss=0.00219]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=> Saving checkpoint\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 56/56 [00:20<00:00,  2.78it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy 99.80276489257812%\n",
      "Dice score 0.6420930027961731\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 112/112 [02:22<00:00,  1.27s/it, loss=0.00807]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=> Saving checkpoint\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 56/56 [00:20<00:00,  2.78it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy 99.80888366699219%\n",
      "Dice score 0.6915411353111267\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 112/112 [02:22<00:00,  1.28s/it, loss=0.00152]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=> Saving checkpoint\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 56/56 [00:20<00:00,  2.78it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy 99.81129455566406%\n",
      "Dice score 0.6538491249084473\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|██████████| 112/112 [02:22<00:00,  1.27s/it, loss=0.0059] \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=> Saving checkpoint\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 56/56 [00:20<00:00,  2.78it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy 99.82037353515625%\n",
      "Dice score 0.7002066969871521\n"
     ]
    }
   ],
   "source": [
    "LEARNING_RATE = 1e-5\n",
    "main()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r\n",
      "  0%|          | 0/112 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=> Loading checkpoint\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 112/112 [02:22<00:00,  1.27s/it, loss=0.00187]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=> Saving checkpoint\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 56/56 [00:20<00:00,  2.78it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy 99.8251724243164%\n",
      "Dice score 0.6943625211715698\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 112/112 [02:22<00:00,  1.27s/it, loss=0.00454]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=> Saving checkpoint\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 56/56 [00:20<00:00,  2.77it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy 99.8243408203125%\n",
      "Dice score 0.6834194660186768\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 112/112 [02:22<00:00,  1.27s/it, loss=0.0039] \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=> Saving checkpoint\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 56/56 [00:20<00:00,  2.77it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy 99.82669067382812%\n",
      "Dice score 0.7175348997116089\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 112/112 [02:22<00:00,  1.28s/it, loss=0.00666]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=> Saving checkpoint\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 56/56 [00:20<00:00,  2.78it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy 99.825439453125%\n",
      "Dice score 0.6996853947639465\n"
     ]
    }
   ],
   "source": [
    "LEARNING_RATE = 5e-6\n",
    "main()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r\n",
      "  0%|          | 0/112 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=> Loading checkpoint\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 112/112 [02:22<00:00,  1.27s/it, loss=0.00398]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=> Saving checkpoint\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 56/56 [00:20<00:00,  2.78it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy 99.82569122314453%\n",
      "Dice score 0.7320536971092224\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 112/112 [02:22<00:00,  1.27s/it, loss=0.000298]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=> Saving checkpoint\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 56/56 [00:20<00:00,  2.77it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy 99.83374786376953%\n",
      "Dice score 0.7188710570335388\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r\n",
      "  0%|          | 0/112 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=> Loading checkpoint\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 112/112 [02:22<00:00,  1.27s/it, loss=0.005]  \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=> Saving checkpoint\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 56/56 [00:20<00:00,  2.78it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy 99.83172607421875%\n",
      "Dice score 0.7115412354469299\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 112/112 [02:22<00:00,  1.27s/it, loss=0.005]  \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=> Saving checkpoint\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 56/56 [00:20<00:00,  2.78it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy 99.83241271972656%\n",
      "Dice score 0.7155706882476807\n"
     ]
    }
   ],
   "source": [
    "LEARNING_RATE = 1e-5\n",
    "main()\n",
    "LEARNING_RATE = 5e-6\n",
    "main()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r\n",
      "  0%|          | 0/112 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=> Loading checkpoint\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 112/112 [02:22<00:00,  1.27s/it, loss=0.00953]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=> Saving checkpoint\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 56/56 [00:20<00:00,  2.78it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy 99.83970642089844%\n",
      "Dice score 0.7425869107246399\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 112/112 [02:22<00:00,  1.27s/it, loss=0.00263]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=> Saving checkpoint\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 56/56 [00:20<00:00,  2.76it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy 99.84320068359375%\n",
      "Dice score 0.7368522882461548\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 112/112 [02:22<00:00,  1.27s/it, loss=0.00136]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=> Saving checkpoint\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 56/56 [00:20<00:00,  2.75it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy 99.84391021728516%\n",
      "Dice score 0.7402313947677612\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 112/112 [02:22<00:00,  1.27s/it, loss=0.000884]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=> Saving checkpoint\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 56/56 [00:20<00:00,  2.77it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy 99.84194946289062%\n",
      "Dice score 0.7370697259902954\n"
     ]
    }
   ],
   "source": [
    "LEARNING_RATE = 5e-6\n",
    "main()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "resnet = timm.create_model(\"resnet34\", pretrained=True)\n",
    "# for name, module in resnet.named_modules():\n",
    "#     print(name)\n",
    "# print(resnet.conv1)\n",
    "resnet.conv1 = nn.Conv2d(10, 64, kernel_size=(\n",
    "    7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
    "# print(resnet.conv1)\n",
    "\n",
    "m = resnet\n",
    "m = nn.Sequential(*list(m.children())[:-2])\n",
    "model = DynamicUnet(m, 1, (512, 512), norm_type=None).to(DEVICE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----------------------------------------------------------------\n",
      "        Layer (type)               Output Shape         Param #\n",
      "================================================================\n",
      "            Conv2d-1         [-1, 64, 256, 256]          31,360\n",
      "       BatchNorm2d-2         [-1, 64, 256, 256]             128\n",
      "              ReLU-3         [-1, 64, 256, 256]               0\n",
      "         MaxPool2d-4         [-1, 64, 128, 128]               0\n",
      "            Conv2d-5         [-1, 64, 128, 128]          36,864\n",
      "       BatchNorm2d-6         [-1, 64, 128, 128]             128\n",
      "          Identity-7         [-1, 64, 128, 128]               0\n",
      "              ReLU-8         [-1, 64, 128, 128]               0\n",
      "          Identity-9         [-1, 64, 128, 128]               0\n",
      "           Conv2d-10         [-1, 64, 128, 128]          36,864\n",
      "      BatchNorm2d-11         [-1, 64, 128, 128]             128\n",
      "             ReLU-12         [-1, 64, 128, 128]               0\n",
      "       BasicBlock-13         [-1, 64, 128, 128]               0\n",
      "           Conv2d-14         [-1, 64, 128, 128]          36,864\n",
      "      BatchNorm2d-15         [-1, 64, 128, 128]             128\n",
      "         Identity-16         [-1, 64, 128, 128]               0\n",
      "             ReLU-17         [-1, 64, 128, 128]               0\n",
      "         Identity-18         [-1, 64, 128, 128]               0\n",
      "           Conv2d-19         [-1, 64, 128, 128]          36,864\n",
      "      BatchNorm2d-20         [-1, 64, 128, 128]             128\n",
      "             ReLU-21         [-1, 64, 128, 128]               0\n",
      "       BasicBlock-22         [-1, 64, 128, 128]               0\n",
      "           Conv2d-23         [-1, 64, 128, 128]          36,864\n",
      "      BatchNorm2d-24         [-1, 64, 128, 128]             128\n",
      "         Identity-25         [-1, 64, 128, 128]               0\n",
      "             ReLU-26         [-1, 64, 128, 128]               0\n",
      "         Identity-27         [-1, 64, 128, 128]               0\n",
      "           Conv2d-28         [-1, 64, 128, 128]          36,864\n",
      "      BatchNorm2d-29         [-1, 64, 128, 128]             128\n",
      "             ReLU-30         [-1, 64, 128, 128]               0\n",
      "       BasicBlock-31         [-1, 64, 128, 128]               0\n",
      "           Conv2d-32          [-1, 128, 64, 64]          73,728\n",
      "      BatchNorm2d-33          [-1, 128, 64, 64]             256\n",
      "         Identity-34          [-1, 128, 64, 64]               0\n",
      "             ReLU-35          [-1, 128, 64, 64]               0\n",
      "         Identity-36          [-1, 128, 64, 64]               0\n",
      "           Conv2d-37          [-1, 128, 64, 64]         147,456\n",
      "      BatchNorm2d-38          [-1, 128, 64, 64]             256\n",
      "           Conv2d-39          [-1, 128, 64, 64]           8,192\n",
      "      BatchNorm2d-40          [-1, 128, 64, 64]             256\n",
      "             ReLU-41          [-1, 128, 64, 64]               0\n",
      "       BasicBlock-42          [-1, 128, 64, 64]               0\n",
      "           Conv2d-43          [-1, 128, 64, 64]         147,456\n",
      "      BatchNorm2d-44          [-1, 128, 64, 64]             256\n",
      "         Identity-45          [-1, 128, 64, 64]               0\n",
      "             ReLU-46          [-1, 128, 64, 64]               0\n",
      "         Identity-47          [-1, 128, 64, 64]               0\n",
      "           Conv2d-48          [-1, 128, 64, 64]         147,456\n",
      "      BatchNorm2d-49          [-1, 128, 64, 64]             256\n",
      "             ReLU-50          [-1, 128, 64, 64]               0\n",
      "       BasicBlock-51          [-1, 128, 64, 64]               0\n",
      "           Conv2d-52          [-1, 128, 64, 64]         147,456\n",
      "      BatchNorm2d-53          [-1, 128, 64, 64]             256\n",
      "         Identity-54          [-1, 128, 64, 64]               0\n",
      "             ReLU-55          [-1, 128, 64, 64]               0\n",
      "         Identity-56          [-1, 128, 64, 64]               0\n",
      "           Conv2d-57          [-1, 128, 64, 64]         147,456\n",
      "      BatchNorm2d-58          [-1, 128, 64, 64]             256\n",
      "             ReLU-59          [-1, 128, 64, 64]               0\n",
      "       BasicBlock-60          [-1, 128, 64, 64]               0\n",
      "           Conv2d-61          [-1, 128, 64, 64]         147,456\n",
      "      BatchNorm2d-62          [-1, 128, 64, 64]             256\n",
      "         Identity-63          [-1, 128, 64, 64]               0\n",
      "             ReLU-64          [-1, 128, 64, 64]               0\n",
      "         Identity-65          [-1, 128, 64, 64]               0\n",
      "           Conv2d-66          [-1, 128, 64, 64]         147,456\n",
      "      BatchNorm2d-67          [-1, 128, 64, 64]             256\n",
      "             ReLU-68          [-1, 128, 64, 64]               0\n",
      "       BasicBlock-69          [-1, 128, 64, 64]               0\n",
      "           Conv2d-70          [-1, 256, 32, 32]         294,912\n",
      "      BatchNorm2d-71          [-1, 256, 32, 32]             512\n",
      "         Identity-72          [-1, 256, 32, 32]               0\n",
      "             ReLU-73          [-1, 256, 32, 32]               0\n",
      "         Identity-74          [-1, 256, 32, 32]               0\n",
      "           Conv2d-75          [-1, 256, 32, 32]         589,824\n",
      "      BatchNorm2d-76          [-1, 256, 32, 32]             512\n",
      "           Conv2d-77          [-1, 256, 32, 32]          32,768\n",
      "      BatchNorm2d-78          [-1, 256, 32, 32]             512\n",
      "             ReLU-79          [-1, 256, 32, 32]               0\n",
      "       BasicBlock-80          [-1, 256, 32, 32]               0\n",
      "           Conv2d-81          [-1, 256, 32, 32]         589,824\n",
      "      BatchNorm2d-82          [-1, 256, 32, 32]             512\n",
      "         Identity-83          [-1, 256, 32, 32]               0\n",
      "             ReLU-84          [-1, 256, 32, 32]               0\n",
      "         Identity-85          [-1, 256, 32, 32]               0\n",
      "           Conv2d-86          [-1, 256, 32, 32]         589,824\n",
      "      BatchNorm2d-87          [-1, 256, 32, 32]             512\n",
      "             ReLU-88          [-1, 256, 32, 32]               0\n",
      "       BasicBlock-89          [-1, 256, 32, 32]               0\n",
      "           Conv2d-90          [-1, 256, 32, 32]         589,824\n",
      "      BatchNorm2d-91          [-1, 256, 32, 32]             512\n",
      "         Identity-92          [-1, 256, 32, 32]               0\n",
      "             ReLU-93          [-1, 256, 32, 32]               0\n",
      "         Identity-94          [-1, 256, 32, 32]               0\n",
      "           Conv2d-95          [-1, 256, 32, 32]         589,824\n",
      "      BatchNorm2d-96          [-1, 256, 32, 32]             512\n",
      "             ReLU-97          [-1, 256, 32, 32]               0\n",
      "       BasicBlock-98          [-1, 256, 32, 32]               0\n",
      "           Conv2d-99          [-1, 256, 32, 32]         589,824\n",
      "     BatchNorm2d-100          [-1, 256, 32, 32]             512\n",
      "        Identity-101          [-1, 256, 32, 32]               0\n",
      "            ReLU-102          [-1, 256, 32, 32]               0\n",
      "        Identity-103          [-1, 256, 32, 32]               0\n",
      "          Conv2d-104          [-1, 256, 32, 32]         589,824\n",
      "     BatchNorm2d-105          [-1, 256, 32, 32]             512\n",
      "            ReLU-106          [-1, 256, 32, 32]               0\n",
      "      BasicBlock-107          [-1, 256, 32, 32]               0\n",
      "          Conv2d-108          [-1, 256, 32, 32]         589,824\n",
      "     BatchNorm2d-109          [-1, 256, 32, 32]             512\n",
      "        Identity-110          [-1, 256, 32, 32]               0\n",
      "            ReLU-111          [-1, 256, 32, 32]               0\n",
      "        Identity-112          [-1, 256, 32, 32]               0\n",
      "          Conv2d-113          [-1, 256, 32, 32]         589,824\n",
      "     BatchNorm2d-114          [-1, 256, 32, 32]             512\n",
      "            ReLU-115          [-1, 256, 32, 32]               0\n",
      "      BasicBlock-116          [-1, 256, 32, 32]               0\n",
      "          Conv2d-117          [-1, 256, 32, 32]         589,824\n",
      "     BatchNorm2d-118          [-1, 256, 32, 32]             512\n",
      "        Identity-119          [-1, 256, 32, 32]               0\n",
      "            ReLU-120          [-1, 256, 32, 32]               0\n",
      "        Identity-121          [-1, 256, 32, 32]               0\n",
      "          Conv2d-122          [-1, 256, 32, 32]         589,824\n",
      "     BatchNorm2d-123          [-1, 256, 32, 32]             512\n",
      "            ReLU-124          [-1, 256, 32, 32]               0\n",
      "      BasicBlock-125          [-1, 256, 32, 32]               0\n",
      "          Conv2d-126          [-1, 512, 16, 16]       1,179,648\n",
      "     BatchNorm2d-127          [-1, 512, 16, 16]           1,024\n",
      "        Identity-128          [-1, 512, 16, 16]               0\n",
      "            ReLU-129          [-1, 512, 16, 16]               0\n",
      "        Identity-130          [-1, 512, 16, 16]               0\n",
      "          Conv2d-131          [-1, 512, 16, 16]       2,359,296\n",
      "     BatchNorm2d-132          [-1, 512, 16, 16]           1,024\n",
      "          Conv2d-133          [-1, 512, 16, 16]         131,072\n",
      "     BatchNorm2d-134          [-1, 512, 16, 16]           1,024\n",
      "            ReLU-135          [-1, 512, 16, 16]               0\n",
      "      BasicBlock-136          [-1, 512, 16, 16]               0\n",
      "          Conv2d-137          [-1, 512, 16, 16]       2,359,296\n",
      "     BatchNorm2d-138          [-1, 512, 16, 16]           1,024\n",
      "        Identity-139          [-1, 512, 16, 16]               0\n",
      "            ReLU-140          [-1, 512, 16, 16]               0\n",
      "        Identity-141          [-1, 512, 16, 16]               0\n",
      "          Conv2d-142          [-1, 512, 16, 16]       2,359,296\n",
      "     BatchNorm2d-143          [-1, 512, 16, 16]           1,024\n",
      "            ReLU-144          [-1, 512, 16, 16]               0\n",
      "      BasicBlock-145          [-1, 512, 16, 16]               0\n",
      "          Conv2d-146          [-1, 512, 16, 16]       2,359,296\n",
      "     BatchNorm2d-147          [-1, 512, 16, 16]           1,024\n",
      "        Identity-148          [-1, 512, 16, 16]               0\n",
      "            ReLU-149          [-1, 512, 16, 16]               0\n",
      "        Identity-150          [-1, 512, 16, 16]               0\n",
      "          Conv2d-151          [-1, 512, 16, 16]       2,359,296\n",
      "     BatchNorm2d-152          [-1, 512, 16, 16]           1,024\n",
      "            ReLU-153          [-1, 512, 16, 16]               0\n",
      "      BasicBlock-154          [-1, 512, 16, 16]               0\n",
      "     BatchNorm2d-155          [-1, 512, 16, 16]           1,024\n",
      "            ReLU-156          [-1, 512, 16, 16]               0\n",
      "          Conv2d-157         [-1, 1024, 16, 16]       4,719,616\n",
      "            ReLU-158         [-1, 1024, 16, 16]               0\n",
      "          Conv2d-159          [-1, 512, 16, 16]       4,719,104\n",
      "            ReLU-160          [-1, 512, 16, 16]               0\n",
      "          Conv2d-161         [-1, 1024, 16, 16]         525,312\n",
      "            ReLU-162         [-1, 1024, 16, 16]               0\n",
      "    PixelShuffle-163          [-1, 256, 32, 32]               0\n",
      "     BatchNorm2d-164          [-1, 256, 32, 32]             512\n",
      "            ReLU-165          [-1, 512, 32, 32]               0\n",
      "          Conv2d-166          [-1, 512, 32, 32]       2,359,808\n",
      "            ReLU-167          [-1, 512, 32, 32]               0\n",
      "          Conv2d-168          [-1, 512, 32, 32]       2,359,808\n",
      "            ReLU-169          [-1, 512, 32, 32]               0\n",
      "       UnetBlock-170          [-1, 512, 32, 32]               0\n",
      "          Conv2d-171         [-1, 1024, 32, 32]         525,312\n",
      "            ReLU-172         [-1, 1024, 32, 32]               0\n",
      "    PixelShuffle-173          [-1, 256, 64, 64]               0\n",
      "     BatchNorm2d-174          [-1, 128, 64, 64]             256\n",
      "            ReLU-175          [-1, 384, 64, 64]               0\n",
      "          Conv2d-176          [-1, 384, 64, 64]       1,327,488\n",
      "            ReLU-177          [-1, 384, 64, 64]               0\n",
      "          Conv2d-178          [-1, 384, 64, 64]       1,327,488\n",
      "            ReLU-179          [-1, 384, 64, 64]               0\n",
      "       UnetBlock-180          [-1, 384, 64, 64]               0\n",
      "          Conv2d-181          [-1, 768, 64, 64]         295,680\n",
      "            ReLU-182          [-1, 768, 64, 64]               0\n",
      "    PixelShuffle-183        [-1, 192, 128, 128]               0\n",
      "     BatchNorm2d-184         [-1, 64, 128, 128]             128\n",
      "            ReLU-185        [-1, 256, 128, 128]               0\n",
      "          Conv2d-186        [-1, 256, 128, 128]         590,080\n",
      "            ReLU-187        [-1, 256, 128, 128]               0\n",
      "          Conv2d-188        [-1, 256, 128, 128]         590,080\n",
      "            ReLU-189        [-1, 256, 128, 128]               0\n",
      "       UnetBlock-190        [-1, 256, 128, 128]               0\n",
      "          Conv2d-191        [-1, 512, 128, 128]         131,584\n",
      "            ReLU-192        [-1, 512, 128, 128]               0\n",
      "    PixelShuffle-193        [-1, 128, 256, 256]               0\n",
      "     BatchNorm2d-194         [-1, 64, 256, 256]             128\n",
      "            ReLU-195        [-1, 192, 256, 256]               0\n",
      "          Conv2d-196         [-1, 96, 256, 256]         165,984\n",
      "            ReLU-197         [-1, 96, 256, 256]               0\n",
      "          Conv2d-198         [-1, 96, 256, 256]          83,040\n",
      "            ReLU-199         [-1, 96, 256, 256]               0\n",
      "       UnetBlock-200         [-1, 96, 256, 256]               0\n",
      "          Conv2d-201        [-1, 384, 256, 256]          37,248\n",
      "            ReLU-202        [-1, 384, 256, 256]               0\n",
      "    PixelShuffle-203         [-1, 96, 512, 512]               0\n",
      "    ResizeToOrig-204         [-1, 96, 512, 512]               0\n",
      "      MergeLayer-205        [-1, 106, 512, 512]               0\n",
      "          Conv2d-206        [-1, 106, 512, 512]         101,230\n",
      "            ReLU-207        [-1, 106, 512, 512]               0\n",
      "          Conv2d-208        [-1, 106, 512, 512]         101,230\n",
      "            ReLU-209        [-1, 106, 512, 512]               0\n",
      "        ResBlock-210        [-1, 106, 512, 512]               0\n",
      "          Conv2d-211          [-1, 1, 512, 512]             107\n",
      "    ToTensorBase-212          [-1, 1, 512, 512]               0\n",
      "================================================================\n",
      "Total params: 41,268,871\n",
      "Trainable params: 41,268,871\n",
      "Non-trainable params: 0\n",
      "----------------------------------------------------------------\n",
      "Input size (MB): 10.00\n",
      "Forward/backward pass size (MB): 3629.00\n",
      "Params size (MB): 157.43\n",
      "Estimated Total Size (MB): 3796.43\n",
      "----------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "from torchsummary import summary\n",
    "summary(model, (10, 512, 512))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
