# https://github.com/galmetzer/z2p
# Raw File
# Tip revision: 082dac85b5bdb8a6ce31dceea6561efd68b16633 authored by Gal Metzer on 05 June 2022, 10:43:42 UTC
# init
# Tip revision: 082dac8
# networks.py
import torch
import torch.nn.functional as F
from torch import nn

# Input width of the Affine layers that DoubleConv builds when ada=True,
# i.e. the expected size of the last dimension of the conditioning vector `w`.
W_SIZE = 512


def calc_mean_std(feat, eps=1e-5):
    """Return the per-sample, per-channel mean and standard deviation.

    Args:
        feat: tensor with exactly four dimensions. The first two dimensions
            (N, C) are kept; everything after them is flattened and reduced.
        eps: small constant added to the variance before taking the square
            root, so the returned std is never exactly zero.

    Returns:
        A tuple ``(mean, std)``; both tensors have shape ``(N, C, 1, 1)``.

    Raises:
        AssertionError: if ``feat`` does not have four dimensions.
    """
    shape = feat.size()
    assert (len(shape) == 4)
    batch, channels = shape[:2]

    # Collapse the trailing dimensions once and reuse the flattened view.
    flat = feat.view(batch, channels, -1)
    std = (flat.var(dim=2) + eps).sqrt().view(batch, channels, 1, 1)
    mean = flat.mean(dim=2).view(batch, channels, 1, 1)
    return mean, std


def adain(content_feat, style_feat):
    """Adaptive instance normalisation of ``content_feat``.

    Each (sample, channel) slice of ``content_feat`` is normalised with its own
    mean/std (from ``calc_mean_std``) and then rescaled and shifted with the
    supplied style statistics.

    Args:
        content_feat: 4-D feature tensor (``calc_mean_std`` requires 4 dims).
        style_feat: pair ``(style_mean, style_std)``. The first two dimensions
            of each entry must equal the first two dimensions of
            ``content_feat``; two trailing singleton dimensions are appended
            before expanding to the content shape.

    Returns:
        Tensor with the same shape as ``content_feat``.

    Raises:
        AssertionError: if the leading two dimensions do not match.
    """
    full_shape = content_feat.size()
    leading = full_shape[:2]
    assert (leading == style_feat[0].size()[:2]) and (
        leading == style_feat[1].size()[:2])

    target_mean, target_std = style_feat
    # Append two singleton axes so the statistics can be expanded over H and W.
    target_mean = target_mean.unsqueeze(-1).unsqueeze(-1)
    target_std = target_std.unsqueeze(-1).unsqueeze(-1)

    own_mean, own_std = calc_mean_std(content_feat)
    whitened = (content_feat - own_mean.expand(full_shape)) / own_std.expand(full_shape)
    return whitened * target_std.expand(full_shape) + target_mean.expand(full_shape)


class FullyConnected(nn.Module):
    """Stack of ``nn.Linear`` layers with linearly interpolated widths.

    The layer widths are ``layers + 1`` evenly spaced values between
    ``input_channels`` and ``output_channels`` (truncated to integers), and one
    linear layer connects each pair of consecutive widths. No activation is
    inserted between the linear layers.

    Args:
        input_channels: size of the last dimension of the input.
        output_channels: size of the last dimension of the output.
        layers: number of linear layers in the stack.
    """

    def __init__(self, input_channels: int, output_channels: int, layers=3):
        super().__init__()
        # Kept as a tensor attribute, exactly like the widths it was derived from.
        self.channels = torch.linspace(input_channels, output_channels, layers + 1).long()
        widths = self.channels.tolist()
        stack = [nn.Linear(n_in, n_out) for n_in, n_out in zip(widths[:-1], widths[1:])]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the linear layers in order and return the result."""
        return self.layers(x)


class Affine(nn.Module):
    """Linear projection plus an additional learnable bias vector.

    The extra bias is separate from the bias already inside ``nn.Linear`` and
    is initialised from a normal distribution with mean 0 and std 1.

    Args:
        input_channels: size of the last dimension of the input.
        output_channels: size of the last dimension of the output.
    """

    def __init__(self, input_channels: int, output_channels):
        super().__init__()
        self.lin = nn.Linear(input_channels, output_channels)
        # nn.init.normal_ fills the tensor in place and returns it.
        offset = nn.init.normal_(torch.zeros(output_channels), 0, 1)
        self.bias = nn.Parameter(offset)

    def forward(self, x):
        """Return ``lin(x)`` shifted by the extra bias."""
        projected = self.lin(x)
        return projected + self.bias


class DoubleConv(nn.Module):
    """(3x3 convolution => normalisation => ReLU) * 2.

    The normalisation step is ``nn.InstanceNorm2d`` (with learnable affine
    parameters) when ``ada`` is False, and ``adain`` driven by ``Affine``
    projections of a conditioning vector ``w`` when ``ada`` is True.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the second convolution.
        mid_channels: channels between the two convolutions; falls back to
            ``out_channels`` when falsy.
        ada: select the adaptive (AdaIN) normalisation path.
        padding: ``padding_mode`` passed to both convolutions.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None, ada=False, padding='zeros'):
        super().__init__()
        mid_channels = mid_channels or out_channels
        self.ada = ada
        self.relu = nn.ReLU(inplace=True)

        # First stage: convolution plus its normalisation parameters.
        self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, padding_mode=padding)
        if ada:
            self.a1_mean = Affine(W_SIZE, mid_channels)
            self.a1_std = Affine(W_SIZE, mid_channels)
        else:
            self.norm1 = nn.InstanceNorm2d(mid_channels, affine=True)

        # Second stage: same layout, mapping mid_channels to out_channels.
        self.conv2 = nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding)
        if ada:
            self.a2_mean = Affine(W_SIZE, out_channels)
            self.a2_std = Affine(W_SIZE, out_channels)
        else:
            self.norm2 = nn.InstanceNorm2d(out_channels, affine=True)

    def forward(self, x, w=None):
        """Run both stages on ``x``.

        Args:
            x: input feature map.
            w: conditioning vector; required (asserted) when ``ada`` is True,
                unused otherwise.
        """
        if not self.ada:
            x = self.relu(self.norm1(self.conv1(x)))
            return self.relu(self.norm2(self.conv2(x)))

        assert w is not None
        x = adain(self.conv1(x), (self.a1_mean(w), self.a1_std(w)))
        x = self.relu(x)
        x = adain(self.conv2(x), (self.a2_mean(w), self.a2_std(w)))
        return self.relu(x)


class DiluteConv(nn.Module):
    """Single dilated 3x3 convolution => InstanceNorm => ReLU.

    The convolution uses ``padding = 1 + dilation``; with a 3x3 kernel this
    makes each spatial dimension of the output 2 larger than the input.

    Args:
        in_channels: channels of the (possibly concatenated) input.
        out_channels: channels produced by the convolution.
        dilation: dilation of the convolution.
        padding: ``padding_mode`` passed to the convolution.
    """

    def __init__(self, in_channels, out_channels, dilation, padding='zeros'):
        super().__init__()
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            dilation=dilation,
            padding=dilation + 1,
            padding_mode=padding,
        )
        self.norm1 = nn.InstanceNorm2d(out_channels, affine=True)

    def forward(self, x, y=None):
        """Apply the block to ``x``; if ``y`` is given it is first concatenated
        to ``x`` along dimension 1."""
        stacked = x if y is None else torch.cat([x, y], dim=1)
        return self.relu(self.norm1(self.conv1(stacked)))


class Down(nn.Module):
    """Downscaling stage: 2x2 max-pool followed by a ``DoubleConv``.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the ``DoubleConv``.
        ada: forwarded to ``DoubleConv`` to select adaptive normalisation.
        padding: forwarded to ``DoubleConv`` as the convolution padding mode.
    """

    def __init__(self, in_channels, out_channels, ada=False, padding='zeros'):
        super().__init__()
        self.max_pool = nn.MaxPool2d(2)
        self.double_conv = DoubleConv(in_channels, out_channels, padding=padding, ada=ada)

    def forward(self, x, w=None):
        """Pool ``x`` and pass the result, together with ``w``, to the double conv."""
        pooled = self.max_pool(x)
        return self.double_conv(pooled, w)


class Up(nn.Module):
    """Upscaling stage: upsample, merge with a skip connection, then ``DoubleConv``.

    Args:
        in_channels: channels of the concatenated (skip + upsampled) tensor
            that is fed to the ``DoubleConv``.
        out_channels: channels produced by the ``DoubleConv``.
        bilinear: if True, upsample with bilinear interpolation (x2,
            ``align_corners=True``) and use ``in_channels // 2`` as the
            intermediate width of the ``DoubleConv``; otherwise upsample with a
            stride-2 transposed convolution that halves the channel count.
        ada: forwarded to ``DoubleConv`` to select adaptive normalisation.
        padding: convolution ``padding_mode`` forwarded to ``DoubleConv``.
            NOTE: this argument used to be accepted but silently ignored, so
            the block always used zero padding; it is now honoured. Results are
            unchanged for the default ``'zeros'``.
    """

    def __init__(self, in_channels, out_channels, bilinear=True, ada=False, padding='zeros'):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2, ada=ada, padding=padding)
        else:
            # ConvTranspose2d only supports zero padding, so `padding` is only
            # forwarded to the DoubleConv, not to the transposed convolution.
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels, ada=ada, padding=padding)

    def forward(self, x1, x2, w=None):
        """Upsample ``x1``, align it with ``x2`` and run the double conv.

        Args:
            x1: feature map to upsample.
            x2: skip-connection feature map; defines the target height/width.
            w: conditioning vector passed through to the ``DoubleConv``.

        Returns:
            Output of the ``DoubleConv`` applied to ``cat([x2, x1], dim=1)``.
        """
        x1 = self.up(x1)
        # Tensors are indexed as NCHW here: dim 2 is height, dim 3 is width.
        diff_y = x2.size()[2] - x1.size()[2]
        diff_x = x2.size()[3] - x1.size()[3]

        # Split the size difference between the two sides of each axis
        # (F.pad takes [left, right, top, bottom] for the last two dims).
        x1 = F.pad(x1, [diff_x // 2, diff_x - diff_x // 2,
                        diff_y // 2, diff_y - diff_y // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x, w)


class OutConv(nn.Module):
    """Final 1x1 convolution mapping ``in_channels`` to ``out_channels``.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output tensor.
        padding: ``padding_mode`` passed to the convolution.
    """

    def __init__(self, in_channels, out_channels, padding='zeros'):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding_mode=padding)

    def forward(self, x):
        """Return the 1x1 convolution of ``x``."""
        projected = self.conv(x)
        return projected
# back to top