# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implements Resnet model.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import tensorflow.compat.v1 as tf
from tensorflow.contrib.layers import l2_regularizer
def get_shape_as_list(x):
  """Returns the static shape of a tensor as a plain Python list."""
  return x.get_shape().as_list()
def fixed_padding(x, kernel_size):
  """Pads the spatial dimensions of `x` for a subsequent 'VALID' convolution.

  Unlike 'SAME' padding, the amount of padding here depends only on the
  kernel size, not on the input resolution or the stride.
  """
  pad_total = kernel_size - 1
  pad_beg = pad_total // 2
  pad_end = pad_total - pad_beg
  x = tf.pad(x, [[0, 0],
                 [pad_beg, pad_end], [pad_beg, pad_end],
                 [0, 0]])
  return x
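# A hedged usage sketch (not part of the original file): for kernel_size=3,
# pad_total is 2, so one row/column of zeros is added on each spatial side,
# and a subsequent 'VALID' convolution then behaves like 'SAME':
#
#   x = tf.zeros([1, 10, 10, 3])
#   y = fixed_padding(x, kernel_size=3)  # -> shape [1, 12, 12, 3]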
def batch_norm(x, training):
return tf.layers.batch_normalization(x, fused=True, training=training)
def identity_norm(x, training):
del training
return x
def maybe_group_conv(x, filters, groups, **kw):
"""Does regular conv or (inefficient) group-conv.
Args:
x: The input image/feature-map.
filters: The total number of filters in the convolution, i.e. number of
channels that the output feature-map will have.
    groups: The number of groups. Setting it to 1 gives a regular convolution;
      setting it to `filters` gives one output channel per group, i.e. a
      depthwise-style convolution (implemented inefficiently here).
**kw: Any further arguments are passed along to the `conv2d` operation.
Returns:
The output feature-map with `filters` channels.
"""
assert filters % groups == 0, ('Filters ({}) not divisible by '
'groups ({}).'.format(filters, groups))
assert x.shape.rank == 4, 'Only implemented for 4D inputs.'
if groups == 1:
return tf.layers.conv2d(x, filters, **kw)
outputs = []
for i, xi in enumerate(tf.split(x, groups, axis=-1)):
with tf.variable_scope('group_{}'.format(i)):
outputs.append(tf.layers.conv2d(xi, filters // groups, **kw))
return tf.concat(outputs, axis=-1)
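# A hedged usage sketch (not part of the original file): with 8 input
# channels and groups=4, the input is split into four 2-channel slices, each
# convolved to 16 // 4 = 4 channels, then concatenated back to 16 channels:
#
#   x = tf.zeros([1, 32, 32, 8])
#   with tf.variable_scope('group_conv_demo'):
#     y = maybe_group_conv(x, filters=16, groups=4, kernel_size=3,
#                          padding='SAME')
#   # y.shape == (1, 32, 32, 16)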
def resblock_v1(x, filters, training, # pylint: disable=missing-docstring
strides=1,
dilation=1,
activation_fn=tf.nn.relu,
normalization_fn=batch_norm,
kernel_regularizer=None,
name='unit'):
with tf.variable_scope(name):
    # Record the input tensor so that it can be used later as the
    # skip-connection.
x_shortcut = x
# Project input if necessary
if (strides > 1) or (filters != x.shape[-1]):
with tf.variable_scope('proj'):
x_shortcut = tf.layers.conv2d(x_shortcut, filters=filters,
kernel_size=1,
strides=strides,
kernel_regularizer=kernel_regularizer,
use_bias=False,
padding='SAME')
x_shortcut = normalization_fn(x_shortcut, training=training)
# First convolution
with tf.variable_scope('a'):
x = fixed_padding(x, kernel_size=3 + 2 * (dilation - 1))
x = tf.layers.conv2d(x, filters=filters,
kernel_size=3,
dilation_rate=dilation,
kernel_regularizer=kernel_regularizer,
use_bias=False,
padding='VALID')
x = normalization_fn(x, training=training)
x = activation_fn(x)
# Second convolution
with tf.variable_scope('b'):
x = fixed_padding(x, kernel_size=3)
x = tf.layers.conv2d(x, filters=filters,
strides=strides,
kernel_size=3,
kernel_regularizer=kernel_regularizer,
use_bias=False,
padding='VALID')
x = normalization_fn(x, training=training)
# Skip connection
x = x_shortcut + x
x = activation_fn(x)
return x
def bottleneck_v1(x, filters, training, # pylint: disable=missing-docstring
strides=1,
dilation=1,
groups=1,
activation_fn=tf.nn.relu,
normalization_fn=batch_norm,
kernel_regularizer=None,
name='unit'):
with tf.variable_scope(name):
    # Record the input tensor so that it can be used later as the
    # skip-connection.
x_shortcut = x
# Project input if necessary
with tf.variable_scope('proj'):
if (strides > 1) or (filters != x.shape[-1]):
x_shortcut = tf.layers.conv2d(x_shortcut, filters=filters,
kernel_size=1, strides=strides,
kernel_regularizer=kernel_regularizer,
use_bias=False,
padding='SAME')
x_shortcut = normalization_fn(x_shortcut, training=training)
    # Note that ResNeXt doubles the middle convolution's channel count!
middle_filters = filters // 4 if groups == 1 else filters // 2
# First convolution
with tf.variable_scope('a'):
      # Note that, unlike the original ResNet paper, we never use a stride in
      # the first convolution. Instead, we apply the stride in the second
      # convolution. The reason is that the first convolution has a 1x1
      # kernel, so striding it would simply discard spatial positions before
      # they are seen (with stride 2, three out of every four pixels are
      # dropped), which loses information.
x = tf.layers.conv2d(x, filters=middle_filters,
kernel_size=1,
kernel_regularizer=kernel_regularizer,
use_bias=False,
padding='SAME')
x = normalization_fn(x, training=training)
x = activation_fn(x)
# Second convolution
with tf.variable_scope('b'):
x = fixed_padding(x, kernel_size=3 + 2 * (dilation - 1))
x = maybe_group_conv(x, filters=middle_filters, groups=groups,
strides=strides,
kernel_size=3,
dilation_rate=dilation,
kernel_regularizer=kernel_regularizer,
use_bias=False,
padding='VALID')
x = normalization_fn(x, training=training)
x = activation_fn(x)
# Third convolution
with tf.variable_scope('c'):
x = tf.layers.conv2d(x, filters=filters,
kernel_size=1,
kernel_regularizer=kernel_regularizer,
use_bias=False,
padding='SAME')
x = normalization_fn(x, training=training)
# Skip connection
x = x_shortcut + x
x = activation_fn(x)
return x
def resblock_v2(x, filters, training, # pylint: disable=missing-docstring
strides=1,
dilation=1,
activation_fn=tf.nn.relu,
normalization_fn=batch_norm,
kernel_regularizer=None,
no_shortcut=False,
out_filters=None,
name='unit'):
with tf.variable_scope(name):
# If the number of output filters is not specified, it defaults to the
# number of input filters.
out_filters = out_filters or filters
    # Record the input tensor so that it can be used later as the
    # skip-connection.
x_shortcut = x
x = normalization_fn(x, training=training)
x = activation_fn(x)
# Project input if necessary
with tf.variable_scope('proj'):
if (strides > 1) or (out_filters != x.shape[-1]):
x_shortcut = tf.layers.conv2d(x, filters=out_filters, kernel_size=1,
strides=strides,
kernel_regularizer=kernel_regularizer,
use_bias=False,
padding='VALID')
# First convolution
with tf.variable_scope('a'):
x = fixed_padding(x, kernel_size=3 + 2 * (dilation - 1))
x = tf.layers.conv2d(x, filters=filters,
kernel_size=3,
kernel_regularizer=kernel_regularizer,
use_bias=False,
dilation_rate=dilation,
padding='VALID')
x = normalization_fn(x, training=training)
x = activation_fn(x)
# Second convolution
with tf.variable_scope('b'):
x = fixed_padding(x, kernel_size=3)
x = tf.layers.conv2d(x, filters=out_filters,
strides=strides,
kernel_size=3,
kernel_regularizer=kernel_regularizer,
use_bias=False,
padding='VALID')
if no_shortcut:
return x
else:
return x + x_shortcut
def bottleneck_v2(x, filters, training, # pylint: disable=missing-docstring
strides=1,
dilation=1,
groups=1,
activation_fn=tf.nn.relu,
normalization_fn=batch_norm,
kernel_regularizer=None,
no_shortcut=False,
out_filters=None,
name='unit'):
with tf.variable_scope(name):
# If the number of output filters is not specified, it defaults to the
# number of input filters.
out_filters = out_filters or filters
    # Record the input tensor so that it can be used later as the
    # skip-connection.
x_shortcut = x
x = normalization_fn(x, training=training)
x = activation_fn(x)
# Project input if necessary
with tf.variable_scope('proj'):
if (strides > 1) or (out_filters != x.shape[-1]):
x_shortcut = tf.layers.conv2d(x, filters=out_filters, kernel_size=1,
strides=strides,
kernel_regularizer=kernel_regularizer,
use_bias=False,
padding='VALID')
    # Note that ResNeXt doubles the middle convolution's channel count!
middle_filters = filters // 4 if groups == 1 else filters // 2
# First convolution
with tf.variable_scope('a'):
      # Note that, unlike the original ResNet paper, we never use a stride in
      # the first convolution. Instead, we apply the stride in the second
      # convolution. The reason is that the first convolution has a 1x1
      # kernel, so striding it would simply discard spatial positions before
      # they are seen (with stride 2, three out of every four pixels are
      # dropped), which loses information.
x = tf.layers.conv2d(x, filters=middle_filters,
kernel_size=1,
kernel_regularizer=kernel_regularizer,
use_bias=False,
padding='SAME')
# Second convolution
with tf.variable_scope('b'):
x = normalization_fn(x, training=training)
x = activation_fn(x)
      # Note that the amount of padding depends on the dilation rate.
x = fixed_padding(x, kernel_size=3 + 2 * (dilation - 1))
x = maybe_group_conv(x, filters=middle_filters, groups=groups,
strides=strides,
kernel_size=3,
dilation_rate=dilation,
kernel_regularizer=kernel_regularizer,
use_bias=False,
padding='VALID')
# Third convolution
with tf.variable_scope('c'):
x = normalization_fn(x, training=training)
x = activation_fn(x)
x = tf.layers.conv2d(x, filters=out_filters,
kernel_size=1,
kernel_regularizer=kernel_regularizer,
use_bias=False,
padding='SAME')
if no_shortcut:
return x
else:
return x + x_shortcut
def resnet(x, # pylint: disable=missing-docstring
is_training,
num_classes=1000,
filters_factor=4,
weight_decay=1e-4,
include_root_block=True,
root_conv_size=7, root_conv_stride=2,
root_pool_size=3, root_pool_stride=2,
activation_fn=tf.nn.relu,
last_relu=True,
normalization_fn=batch_norm,
strides=(2, 2, 2),
dilations=(1, 1, 1, 1),
num_layers=(3, 4, 6, 3),
global_pool=True,
unit='bottleneck',
mode='v2',
representation_size=None,
spatial_squeeze=True,
groups=1):
unit_kw = {
'activation_fn': activation_fn,
'normalization_fn': normalization_fn,
'training': is_training,
}
mult = 1
if unit == 'bottleneck':
unit = bottleneck_v2 if mode == 'v2' else bottleneck_v1
mult = 4
unit_kw['groups'] = groups
elif unit == 'resblock':
unit = resblock_v2 if mode == 'v2' else resblock_v1
assert groups == 1, 'Groups not supported in "resblock".'
else:
raise ValueError('Unknown resnet unit: %s' % unit)
strides = list(strides)[::-1]
dilations = list(dilations)[::-1]
num_layers = list(num_layers)[::-1]
end_points = {}
filters = 16 * filters_factor
kernel_regularizer = l2_regularizer(scale=weight_decay)
unit_kw['kernel_regularizer'] = kernel_regularizer
if include_root_block:
with tf.variable_scope('root_block'):
x = fixed_padding(x, kernel_size=root_conv_size)
x = tf.layers.conv2d(x, filters=filters,
kernel_size=root_conv_size,
strides=root_conv_stride,
padding='VALID', use_bias=False,
kernel_regularizer=kernel_regularizer)
if mode == 'v1':
x = normalization_fn(x, training=is_training)
x = activation_fn(x)
x = fixed_padding(x, kernel_size=root_pool_size)
x = tf.layers.max_pooling2d(x, pool_size=root_pool_size,
strides=root_pool_stride, padding='VALID')
end_points['after_root'] = x
with tf.variable_scope('block1'):
filters *= mult
unit_kw.update({'dilation': dilations.pop()})
for i in range(num_layers.pop()):
x = unit(x, filters, strides=1, name='unit%d' % (i + 1), **unit_kw)
end_points['block1'] = x
with tf.variable_scope('block2'):
filters *= 2
unit_kw.update({'dilation': dilations.pop()})
x = unit(x, filters, strides=strides.pop(), name='unit1', **unit_kw)
for i in range(1, num_layers.pop()):
x = unit(x, filters, strides=1, name='unit%d' % (i + 1), **unit_kw)
end_points['block2'] = x
with tf.variable_scope('block3'):
filters *= 2
unit_kw.update({'dilation': dilations.pop()})
x = unit(x, filters, strides=strides.pop(), name='unit1', **unit_kw)
for i in range(1, num_layers.pop()):
x = unit(x, filters, strides=1, name='unit%d' % (i + 1), **unit_kw)
end_points['block3'] = x
with tf.variable_scope('block4'):
filters *= 2
unit_kw.update({'dilation': dilations.pop()})
nlayers = num_layers.pop()
if nlayers > 1:
x = unit(x, filters, strides=strides.pop(), name='unit1', **unit_kw)
for i in range(1, nlayers - 1):
x = unit(x, filters, strides=1, name='unit%d' % (i + 1), **unit_kw)
# representation_size is the number of dimensions of the final output,
# right before (and after) the global average pooling. By default, it's
# simply the number of filters that we got there with the given
# architecture, but optionally we may want to control it explicitly and
# independently.
x = unit(x, representation_size or filters, strides=1,
name='unit%d' % nlayers, **unit_kw)
else:
# But in the case of just one block, do everything in that block.
x = unit(x, representation_size or filters, strides=strides.pop(),
name='unit1', **unit_kw)
end_points['block4'] = x
if (mode == 'v1') and (not last_relu):
raise ValueError('last_relu should be set to True in the v1 mode.')
if mode == 'v2':
x = normalization_fn(x, training=is_training)
if last_relu:
x = activation_fn(x)
if global_pool:
x = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
end_points['pre_logits'] = tf.squeeze(x, [1, 2]) if spatial_squeeze else x
else:
end_points['pre_logits'] = x
if num_classes:
with tf.variable_scope('head'):
logits = tf.layers.conv2d(x, filters=num_classes,
kernel_size=1,
kernel_regularizer=kernel_regularizer)
if global_pool and spatial_squeeze:
logits = tf.squeeze(logits, [1, 2])
end_points['logits'] = logits
return logits, end_points
else:
return end_points['pre_logits'], end_points
resnet18 = functools.partial(resnet, num_layers=(2, 2, 2, 2),
unit='resblock')
resnet34 = functools.partial(resnet, num_layers=(3, 4, 6, 3),
unit='resblock')
resnet50 = functools.partial(resnet, num_layers=(3, 4, 6, 3),
unit='bottleneck')
resnet101 = functools.partial(resnet, num_layers=(3, 4, 23, 3),
unit='bottleneck')
resnet152 = functools.partial(resnet, num_layers=(3, 8, 36, 3),
unit='bottleneck')
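# A minimal end-to-end sketch of how these constructors are typically used in
# a TF1 graph/session workflow (an illustration, not part of the original
# file; assumes `import numpy as np`):
#
#   images = tf.placeholder(tf.float32, [None, 224, 224, 3])
#   logits, end_points = resnet50(images, is_training=False)
#   # logits has shape [None, 1000]; intermediate feature maps are available
#   # in end_points['block1'] ... end_points['block4'] and 'pre_logits'.
#   with tf.Session() as sess:
#     sess.run(tf.global_variables_initializer())
#     out = sess.run(logits, {images: np.zeros((1, 224, 224, 3), np.float32)})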
# Experimental code ########################################
# "Reversible" resnet ######################################
# Invertible residual block as outlined in https://arxiv.org/abs/1707.04585
def bottleneck_rev(x, training, # pylint: disable=missing-docstring
activation_fn=tf.nn.relu,
normalization_fn=batch_norm,
dilation=1,
kernel_regularizer=None,
simple=False,
unit='bottleneck',
name='unit_rev'):
if unit == 'bottleneck':
unit = bottleneck_v2
elif unit == 'resblock':
unit = resblock_v2
else:
raise ValueError('Unknown resnet unit: %s' % unit)
x1, x2 = tf.split(x, 2, 3)
with tf.variable_scope(name):
y1 = x1 + unit(x2, x2.shape[-1], training,
strides=1,
activation_fn=activation_fn,
normalization_fn=normalization_fn,
kernel_regularizer=kernel_regularizer,
dilation=dilation,
no_shortcut=True,
name='unit1')
if simple:
y2 = x2
      # The swap of the 'y' parts is intentional here, as the 'simple' block
      # processes only one half of the input. Without this swap, the same
      # half would be processed repeatedly. This is standard practice from
      # e.g. the RealNVP and i-RevNet papers.
return tf.concat([y2, y1], axis=3)
else:
y2 = x2 + unit(y1, y1.shape[-1], training,
strides=1,
activation_fn=activation_fn,
normalization_fn=normalization_fn,
kernel_regularizer=kernel_regularizer,
dilation=dilation,
no_shortcut=True,
name='unit2')
return tf.concat([y1, y2], axis=3)
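# The additive coupling above is invertible in closed form (a sketch of the
# algebra, not code from the original file). With residual branches F ('unit1')
# and G ('unit2'):
#
#   forward:  y1 = x1 + F(x2)        inverse:  x2 = y2 - G(y1)
#             y2 = x2 + G(y1)                  x1 = y1 - F(x2)
#
# so intermediate activations can be recomputed from the outputs during
# backprop instead of being stored (see https://arxiv.org/abs/1707.04585).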
# This operation is not, strictly speaking, invertible. However, in practice
# it preserves a large amount of information and can be inverted up to some
# error; see the shape sketch after the function.
def pool_and_double_channels(x, stride):
if stride > 1:
x = tf.layers.average_pooling2d(x, pool_size=stride, strides=stride,
padding='SAME')
return tf.pad(x, [[0, 0], [0, 0], [0, 0],
[x.shape[3] // 2, x.shape[3] // 2]])
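# A shape sketch (illustrative numbers, not from the original file): with
# stride=2, an input of shape [N, 32, 32, 64] is average-pooled to
# [N, 16, 16, 64] and then zero-padded along channels to [N, 16, 16, 128].
# The average pooling is what loses information and makes the inverse only
# approximate.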
def revnet(x, # pylint: disable=missing-docstring
is_training,
num_classes=1000,
filters_factor=4,
weight_decay=1e-4,
include_root_block=True,
root_conv_size=7, root_conv_stride=2,
root_pool_size=3, root_pool_stride=2,
strides=(2, 2, 2),
dilations=(1, 1, 1, 1),
num_layers=(3, 4, 6, 3),
global_pool=True,
activation_fn=tf.nn.relu,
normalization_fn=batch_norm,
last_relu=False,
mode='v2',
inside_unit='bottleneck',
representation_size=None,
regularize_last_proj=False,
spatial_squeeze=True):
del mode # unused parameter, exists for compatibility with resnet function
  # Use the simple block. Note that two consecutive simple blocks are
  # equivalent to one full (non-simple) block; see the sketch below.
unit = functools.partial(bottleneck_rev, simple=True)
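  # A sketch of that equivalence (an illustration, not from the original
  # file), writing the two channel halves as (a, b):
  #   one simple block:      (a, b) -> (b, a + F1(b))
  #   a second simple block:        -> (a + F1(b), b + F2(a + F1(b)))
  # which is exactly the full coupling y1 = x1 + F1(x2), y2 = x2 + F2(y1).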
mult = 1
if inside_unit == 'bottleneck':
mult = 4
strides = list(strides)[::-1]
dilations = list(dilations)[::-1]
num_layers = list(num_layers)[::-1]
end_points = {}
filters = 16 * filters_factor
kernel_regularizer = l2_regularizer(scale=weight_decay)
  # The first convolution serves as a (randomly initialized) projection that
  # increases the number of channels. It is not possible to skip it.
with tf.variable_scope('root_block'):
x = fixed_padding(x, kernel_size=root_conv_size)
x = tf.layers.conv2d(x, filters=filters * mult,
kernel_size=root_conv_size,
strides=root_conv_stride,
padding='VALID', use_bias=False,
kernel_regularizer=None)
if include_root_block:
x = fixed_padding(x, kernel_size=root_pool_size)
x = tf.layers.max_pooling2d(
x, pool_size=root_pool_size, strides=root_pool_stride,
padding='VALID')
end_points['after_root'] = x
params = {'activation_fn': activation_fn,
'normalization_fn': normalization_fn,
'training': is_training,
'kernel_regularizer': kernel_regularizer,
'unit': inside_unit
}
with tf.variable_scope('block1'):
params.update({'dilation': dilations.pop()})
for i in range(num_layers.pop()):
x = unit(x, name='block%i' % i, **params)
x = pool_and_double_channels(x, strides.pop())
end_points['block1'] = x
with tf.variable_scope('block2'):
params.update({'dilation': dilations.pop()})
for i in range(num_layers.pop()):
x = unit(x, name='block%i' % i, **params)
x = pool_and_double_channels(x, strides.pop())
end_points['block2'] = x
with tf.variable_scope('block3'):
params.update({'dilation': dilations.pop()})
for i in range(num_layers.pop()):
x = unit(x, name='block%i' % i, **params)
x = pool_and_double_channels(x, strides.pop())
end_points['block3'] = x
with tf.variable_scope('block4'):
params.update({'dilation': dilations.pop()})
for i in range(num_layers.pop()):
x = unit(x, name='block%i' % i, **params)
end_points['block4'] = x
if representation_size:
with tf.variable_scope('resize'):
x = activation_fn(normalization_fn(x, training=is_training))
kern_reg = kernel_regularizer if regularize_last_proj else None
x = tf.layers.conv2d(x, filters=representation_size, use_bias=False,
kernel_size=1, kernel_regularizer=kern_reg)
with tf.variable_scope('head'):
x = normalization_fn(x, training=is_training)
if last_relu:
x = activation_fn(x)
if global_pool:
x = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
end_points['pre_logits'] = tf.squeeze(x, [1, 2]) if spatial_squeeze else x
else:
end_points['pre_logits'] = x
if num_classes:
logits = tf.layers.conv2d(x, filters=num_classes,
kernel_size=1,
kernel_regularizer=kernel_regularizer)
if global_pool and spatial_squeeze:
logits = tf.squeeze(logits, [1, 2])
end_points['logits'] = logits
return logits, end_points
else:
return end_points['pre_logits'], end_points
revnet18 = functools.partial(revnet, num_layers=(2, 2, 2, 2),
inside_unit='resblock')
revnet34 = functools.partial(revnet, num_layers=(3, 4, 6, 3),
inside_unit='resblock')
revnet50 = functools.partial(revnet, num_layers=(3, 4, 6, 3),
inside_unit='bottleneck')
revnet101 = functools.partial(revnet, num_layers=(3, 4, 23, 3),
inside_unit='bottleneck')
revnet152 = functools.partial(revnet, num_layers=(3, 8, 36, 3),
inside_unit='bottleneck')
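# Usage mirrors the resnet constructors above (a hedged sketch, not part of
# the original file):
#
#   images = tf.placeholder(tf.float32, [None, 224, 224, 3])
#   logits, end_points = revnet50(images, is_training=True, num_classes=10)
#   # logits has shape [None, 10]; activations inside the reversible blocks
#   # could in principle be recomputed rather than stored.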
# Even more experimental code ########################################
# Fully-Reversible resnet aka iRevnet ################################
# Reimplementation of iRevnet from: https://openreview.net/forum?id=HJsjkMb0Z
def irevnet300(x, # pylint: disable=missing-docstring
is_training,
num_classes=1000,
weight_decay=1e-4,
activation_fn=tf.nn.relu,
normalization_fn=batch_norm):
unit = functools.partial(bottleneck_rev, simple=True)
end_points = {}
kernel_regularizer = l2_regularizer(scale=weight_decay)
  # Do the invertible pooling first.
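  # space_to_depth is exactly invertible via depth_to_space (a hedged sketch,
  # not part of the original file):
  #
  #   x = tf.random_uniform([1, 8, 8, 3])
  #   y = tf.space_to_depth(x, 2)       # shape [1, 4, 4, 12]
  #   x_back = tf.depth_to_space(y, 2)  # identical to x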
x = tf.space_to_depth(x, 2)
layer_counts = [1, 6, 16, 72, 5]
params = {'activation_fn': activation_fn,
'normalization_fn': normalization_fn,
'training': is_training,
'kernel_regularizer': kernel_regularizer
}
for num_block, layer_count in enumerate(layer_counts):
for _ in range(layer_count):
x = unit(x, **params)
end_points['block%i' % (num_block + 1)] = x
if num_block < (len(layer_counts) - 1):
x = tf.space_to_depth(x, 2)
x = normalization_fn(x, training=is_training)
end_points['last_invertible'] = x
# Non-invertible part starts here
x = activation_fn(x)
x = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
end_points['pre_logits'] = tf.squeeze(x, [1, 2])
logits = tf.squeeze(tf.layers.conv2d(x, filters=num_classes,
kernel_size=1,
kernel_regularizer=kernel_regularizer),
[1, 2])
end_points['logits'] = logits
return logits, end_points