https://github.com/szagoruyko/diracnets
Tip revision: 3c6f9863e983fdd56db2372617d9fd0d2c838125 ("Merge pull request #19 from szagoruyko/pytorch0.4"), authored by Sergey Zagoruyko on 09 June 2018, 17:46:56 UTC
diracnet.py
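
"""DiracNets in functional form (Zagoruyko & Komodakis, "DiracNets: Training Very Deep
Neural Networks Without Skip-Connections").

`define_diracnet` builds CIFAR and ImageNet models as a pure function plus a flat dict
of named parameter tensors; the `DiracNet` nn.Module below mirrors the ImageNet variant
with the Dirac parameterization folded into plain convolution weights, which is the form
the released pretrained checkpoints load into.
"""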
from functools import partial
from nested_dict import nested_dict
import torch
import torch.nn.functional as F
from torch.nn.init import dirac_, kaiming_normal_
from torch.nn.parallel._functions import Broadcast
from torch.nn.parallel import scatter, parallel_apply, gather
from torch import nn
from torch.utils import model_zoo


def cast(params, dtype='float'):
    if isinstance(params, dict):
        return {k: cast(v, dtype) for k,v in params.items()}
    else:
        return getattr(params.cuda() if torch.cuda.is_available() else params, dtype)()


def conv_params(ni, no, k=1):
    return kaiming_normal_(torch.Tensor(no, ni, k, k))


def linear_params(ni, no):
    return {'weight': kaiming_normal_(torch.Tensor(no, ni)), 'bias': torch.zeros(no)}


def bnparams(n):
    return {'weight': torch.rand(n),
            'bias': torch.zeros(n),
            'running_mean': torch.zeros(n),
            'running_var': torch.ones(n)}


def data_parallel(f, input, params, mode, device_ids, output_device=None):
    device_ids = list(device_ids)
    if output_device is None:
        output_device = device_ids[0]

    if len(device_ids) == 1:
        return f(input, params, mode)

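    # Broadcast every parameter tensor to all devices in one call; Broadcast returns a flat
    # tuple grouped per device, which is rebuilt into one parameter dict per replica below.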
    params_all = Broadcast.apply(device_ids, *params.values())
    params_replicas = [{k: params_all[i + j*len(params)] for i, k in enumerate(params.keys())}
                       for j in range(len(device_ids))]

    replicas = [partial(f, params=p, mode=mode)
                for p in params_replicas]
    inputs = scatter([input], device_ids)
    outputs = parallel_apply(replicas, inputs)
    return gather(outputs, output_device)


def flatten(params):
    return {'.'.join(k): v for k, v in nested_dict(params).items_flat() if v is not None}


def batch_norm(x, params, base, mode):
    return F.batch_norm(x, weight=params[base + '.weight'],
                        bias=params[base + '.bias'],
                        running_mean=params[base + '.running_mean'],
                        running_var=params[base + '.running_var'],
                        training=mode)


def print_tensor_dict(params):
    kmax = max(len(key) for key in params.keys())
    for i, (key, v) in enumerate(params.items()):
        print(str(i).ljust(5), key.ljust(kmax + 3),
              str(tuple(v.shape)).ljust(23), torch.typename(v), v.requires_grad)


def set_requires_grad_except_bn_(params):
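    """Enable gradients on every tensor except the batch-norm running statistics."""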
    for k, v in params.items():
        if not k.endswith('running_mean') and not k.endswith('running_var'):
            v.requires_grad = True


def size2name(size):
    return 'eye' + '_'.join(map(str, size))


def block(o, params, base, mode, j):
    w = params[base + '.conv']
    alpha = params[base + '.alpha'].view(-1, 1, 1, 1)
    beta = params[base + '.beta'].view(-1, 1, 1, 1)
    delta = params[size2name(w.shape)]
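    # DiracNet re-parameterization: the effective filter is
    #     W_hat = beta * W / ||W|| + alpha * delta,
    # where delta is a Dirac-initialized (identity) kernel shared by all convolutions of the
    # same shape, and alpha / beta are learned per-output-channel scales.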
    w = beta * F.normalize(w.view(w.shape[0], -1)).view_as(w) + alpha * delta
    o = F.conv2d(F.relu(o), w, stride=1, padding=1)
    o = batch_norm(o, params, base + '.bn', mode)
    return o


def group(o, params, base, mode, count):
    for i in range(count):
        o = block(o, params, '%s.block%d' % (base, i), mode, i)
    return o


def define_diracnet(depth, width, dataset):
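    """Return (f, flat_params) for a functional DiracNet.

    `dataset` is 'CIFAR10', 'CIFAR100' or 'ImageNet'; `width` is a channel multiplier.
    For CIFAR each of the three groups contains (depth - 4) // 6 * 2 blocks; for ImageNet
    the depth must be 18 or 34.
    """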

    def gen_group_params(ni, no, count):
        return {'block%d' % i: {'conv': conv_params(ni if i == 0 else no, no, k=3),
                                'alpha': torch.full((no,), 1.0),
                                'beta': torch.full((no,), 0.1),
                                'bn': bnparams(no)} for i in range(count)}

    if dataset.startswith('CIFAR'):
        n = (depth - 4) // 6
        widths = [int(v * width) for v in (16, 32, 64)]

        def f(inputs, params, mode):
            o = F.conv2d(inputs, params['conv'], padding=1)
            o = F.relu(batch_norm(o, params, 'bn', mode))
            o = group(o, params, 'group0', mode, n * 2)
            o = F.max_pool2d(o, 2)
            o = group(o, params, 'group1', mode, n * 2)
            o = F.max_pool2d(o, 2)
            o = group(o, params, 'group2', mode, n * 2)
            o = F.avg_pool2d(F.relu(o), 8)
            o = F.linear(o.view(o.size(0), -1), params['fc.weight'], params['fc.bias'])
            return o

        params = {
            'conv': kaiming_normal_(torch.Tensor(widths[0], 3, 3, 3)),
            'bn': bnparams(widths[0]),
            'group0': gen_group_params(widths[0], widths[0], n * 2),
            'group1': gen_group_params(widths[0], widths[1], n * 2),
            'group2': gen_group_params(widths[1], widths[2], n * 2),
            'fc': linear_params(widths[2], 10 if dataset == 'CIFAR10' else 100),
        }

    elif dataset == 'ImageNet':
        definitions = {18: [2, 2, 2, 2],
                       34: [3, 4, 6, 3]}
        widths = [int(width * v) for v in (64, 128, 256, 512)]
        blocks = definitions[depth]

        def f(inputs, params, mode):
            o = F.conv2d(inputs, params['conv'], padding=3, stride=2)
            o = batch_norm(o, params, 'bn', mode)
            o = F.max_pool2d(o, 3, 2, 1)
            o = group(o, params, 'group0', mode, blocks[0] * 2)
            o = F.max_pool2d(o, 2)
            o = group(o, params, 'group1', mode, blocks[1] * 2)
            o = F.max_pool2d(o, 2)
            o = group(o, params, 'group2', mode, blocks[2] * 2)
            o = F.max_pool2d(o, 2)
            o = group(o, params, 'group3', mode, blocks[3] * 2)
            o = F.avg_pool2d(F.relu(o), o.size(-1))
            o = F.linear(o.view(o.size(0), -1), params['fc.weight'], params['fc.bias'])
            return o

        params = {
            'conv': kaiming_normal_(torch.Tensor(widths[0], 3, 7, 7)),
            'group0': gen_group_params(widths[0], widths[0], 2 * blocks[0]),
            'group1': gen_group_params(widths[0], widths[1], 2 * blocks[1]),
            'group2': gen_group_params(widths[1], widths[2], 2 * blocks[2]),
            'group3': gen_group_params(widths[2], widths[3], 2 * blocks[3]),
            'bn': bnparams(widths[0]),
            'fc': linear_params(widths[-1], 1000),
        }
    else:
        raise ValueError('dataset not understood')

    flat_params = flatten(params)

    flat_params = {k: cast(v.data) for k, v in flat_params.items()}

    set_requires_grad_except_bn_(flat_params)

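    # Register one Dirac-initialized (identity) kernel per unique conv weight shape;
    # block() mixes it with the normalized learned weights through alpha and beta.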
    for k, v in list(flat_params.items()):
        if k.find('.conv') > -1:
            flat_params[size2name(v.size())] = dirac_(v.data.clone())

    return f, flat_params
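

# Usage sketch (an assumption about intended usage, not part of the original file):
# `f` is a pure function of (input, params, mode) and `flat_params` is a plain dict of
# named tensors, so the model can be inspected and trained without nn.Module, e.g.
#
#     f, params = define_diracnet(depth=34, width=1, dataset='ImageNet')
#     print_tensor_dict(params)
#     # x: a normalized batch of 3x224x224 images
#     y = data_parallel(f, x, params, mode=True,
#                       device_ids=range(torch.cuda.device_count()))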


model_urls = {
    'diracnet18': 'https://s3.amazonaws.com/modelzoo-networks/diracnet18v2folded-a2174e15.pth',
    'diracnet34': 'https://s3.amazonaws.com/modelzoo-networks/diracnet34v2folded-dfb15d34.pth'
}


class DiracNet(nn.Module):
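    """ImageNet DiracNet as a plain nn.Module, with the Dirac parameterization already
    folded into the convolution weights (the form used by the pretrained checkpoints)."""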

    widths = (64, 128, 256, 512)
    block_depths = {18: [v * 2 for v in (2, 2, 2, 2)],
                    34: [v * 2 for v in (3, 4, 6, 3)]}

    def __init__(self, depth=18):
        super().__init__()
        self.features = nn.Sequential()
        n_channels = self.widths[0]
        self.features.add_module('conv', nn.Conv2d(3, n_channels, kernel_size=7, stride=2, padding=3))
        self.features.add_module('max_pool0', nn.MaxPool2d(3, 2, 1))
        for group_id, (width, block_depth) in enumerate(zip(self.widths, self.block_depths[depth])):
            for block_id in range(block_depth):
                name = 'group{}.block{}.'.format(group_id, block_id)
                self.features.add_module(name + 'relu', nn.ReLU())
                self.features.add_module(name + 'conv', nn.Conv2d(n_channels, width, kernel_size=3, padding=1))
                n_channels = width
            if group_id != 3:
                self.features.add_module('max_pool{}'.format(group_id + 1), nn.MaxPool2d(2))
            else:
                self.features.add_module('last_relu', nn.ReLU())
                self.features.add_module('avg_pool', nn.AvgPool2d(7))
        self.fc = nn.Linear(in_features=512, out_features=1000)

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.shape[0], -1)
        x = self.fc(x)
        return x


def diracnet18(pretrained=False):
    model = DiracNet(18)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['diracnet18']))
    return model


def diracnet34(pretrained=False):
    model = DiracNet(34)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['diracnet34']))
    return model
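

# Minimal smoke test (a sketch of intended usage, not from the original file): build the
# module-based ImageNet model and check the output shape on random data; pass
# pretrained=True to download the released folded weights instead. Real images should be
# resized to 224x224 and normalized with the usual ImageNet statistics before inference.
if __name__ == '__main__':
    model = diracnet18(pretrained=False).eval()
    with torch.no_grad():
        logits = model(torch.randn(1, 3, 224, 224))
    print(logits.shape)  # expected: torch.Size([1, 1000])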
