Revision a2ec3752f54bfc3b40e7952234fbeb5452ed63e3 authored by Wenqi Li on 19 December 2022, 16:49:12 UTC, committed by Wenqi Li on 19 December 2022, 16:49:16 UTC
Signed-off-by: Wenqi Li <wenqil@nvidia.com>
2 parent s d865264 + b159ce7
Raw File
test_senet.py
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import unittest
from typing import TYPE_CHECKING
from unittest import skipUnless

import torch
from parameterized import parameterized

import monai.networks.nets.senet as se_mod
from monai.networks import eval_mode
from monai.networks.nets import SENet, SENet154, SEResNet50, SEResNet101, SEResNet152, SEResNext50, SEResNext101
from monai.utils import optional_import
from tests.utils import test_is_quick, test_pretrained_networks, test_script_save, testing_data_config

if TYPE_CHECKING:
    import pretrainedmodels

    has_cadene_pretrain = True
else:
    pretrainedmodels, has_cadene_pretrain = optional_import("pretrainedmodels")

device = "cuda" if torch.cuda.is_available() else "cpu"

NET_ARGS = {"spatial_dims": 3, "in_channels": 2, "num_classes": 2}
TEST_CASE_1 = [SENet154, NET_ARGS]
TEST_CASE_2 = [SEResNet50, NET_ARGS]
TEST_CASE_3 = [SEResNet101, NET_ARGS]
TEST_CASE_4 = [SEResNet152, NET_ARGS]
TEST_CASE_5 = [SEResNext50, NET_ARGS]
TEST_CASE_6 = [SEResNext101, NET_ARGS]
TEST_CASE_7 = [
    SENet,
    {
        "spatial_dims": 3,
        "in_channels": 2,
        "num_classes": 2,
        "block": "se_bottleneck",
        "layers": (3, 8, 36, 3),
        "groups": 64,
        "reduction": 16,
    },
]

TEST_CASE_PRETRAINED_1 = [SEResNet50, {"spatial_dims": 2, "in_channels": 3, "num_classes": 2, "pretrained": True}]


class TestSENET(unittest.TestCase):
    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7])
    def test_senet_shape(self, net, net_args):
        input_data = torch.randn(2, 2, 64, 64, 64).to(device)
        expected_shape = (2, 2)
        net = net(**net_args).to(device)
        with eval_mode(net):
            result = net(input_data)
            self.assertEqual(result.shape, expected_shape)

    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6])
    def test_script(self, net, net_args):
        net = net(**net_args)
        input_data = torch.randn(2, 2, 64, 64, 64)
        test_script_save(net, input_data)


class TestPretrainedSENET(unittest.TestCase):
    def setUp(self):
        self.original_urls = se_mod.SE_NET_MODELS.copy()
        replace_url = test_is_quick()
        if not replace_url:
            try:
                SEResNet50(pretrained=True, spatial_dims=2, in_channels=3, num_classes=2)
            except OSError as rt_e:
                print(rt_e)
                if "certificate" in str(rt_e):  # [SSL: CERTIFICATE_VERIFY_FAILED]
                    replace_url = True
        if replace_url:
            testing_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "testing_data")
            testing_data_urls = {
                "senet154": {
                    "url": testing_data_config("models", "senet154-c7b49a05", "url"),
                    "filename": "senet154-c7b49a05.pth",
                },
                "se_resnet50": {
                    "url": testing_data_config("models", "se_resnet50-ce0d4300", "url"),
                    "filename": "se_resnet50-ce0d4300.pth",
                },
                "se_resnet101": {
                    "url": testing_data_config("models", "se_resnet101-7e38fcc6", "url"),
                    "filename": "se_resnet101-7e38fcc6.pth",
                },
                "se_resnet152": {
                    "url": testing_data_config("models", "se_resnet152-d17c99b7", "url"),
                    "filename": "se_resnet152-d17c99b7.pth",
                },
                "se_resnext50_32x4d": {
                    "url": testing_data_config("models", "se_resnext50_32x4d-a260b3a4", "url"),
                    "filename": "se_resnext50_32x4d-a260b3a4.pth",
                },
                "se_resnext101_32x4d": {
                    "url": testing_data_config("models", "se_resnext101_32x4d-3b2fe3d8", "url"),
                    "filename": "se_resnext101_32x4d-3b2fe3d8.pth",
                },
            }
            for item in testing_data_urls:
                testing_data_urls[item]["filename"] = os.path.join(testing_dir, testing_data_urls[item]["filename"])
            se_mod.SE_NET_MODELS = testing_data_urls

    def tearDown(self):
        se_mod.SE_NET_MODELS = self.original_urls.copy()

    @parameterized.expand([TEST_CASE_PRETRAINED_1])
    def test_senet_shape(self, model, input_param):
        net = test_pretrained_networks(model, input_param, device)
        input_data = torch.randn(3, 3, 64, 64).to(device)
        expected_shape = (3, 2)
        net = net.to(device)
        with eval_mode(net):
            result = net(input_data)
            self.assertEqual(result.shape, expected_shape)

    @parameterized.expand([TEST_CASE_PRETRAINED_1])
    @skipUnless(has_cadene_pretrain, "Requires `pretrainedmodels` package.")
    def test_pretrain_consistency(self, model, input_param):
        input_data = torch.randn(1, 3, 64, 64).to(device)
        net = test_pretrained_networks(model, input_param, device)
        with eval_mode(net):
            result = net.features(input_data)
        cadene_net = pretrainedmodels.se_resnet50().to(device)
        with eval_mode(cadene_net):
            expected_result = cadene_net.features(input_data)
        # The difference between Cadene's senet and our version is that
        # we use nn.Linear as the FC layer, but Cadene's version uses
        # a conv layer with kernel size equals to 1. It may bring a little difference.
        self.assertTrue(torch.allclose(result, expected_result, rtol=1e-5, atol=1e-5))


if __name__ == "__main__":
    unittest.main()
back to top