# sentence_encoders.py
# Source: https://gitlab.com/mcoavoux/mtgpy-release-findings-2021.git
# Revision: c9972219cd75049269d26632d2bb79619d661298 (mcoavoux, 20 May 2021)
import torch
import torch.nn as nn
import torch.nn.init
import torch.nn.functional as F
# Stolen from: https://github.com/yzhangcs/biaffine-parser/blob/master/parser/modules/dropout.py
class SharedDropout(nn.Module):
def __init__(self, p=0.5, batch_first=True):
super(SharedDropout, self).__init__()
self.p = p
self.batch_first = batch_first
def forward(self, x):
if self.training and self.p > 0.:
if self.batch_first:
mask = self.get_mask(x[:, 0], self.p)
else:
mask = self.get_mask(x[0], self.p)
x *= mask.unsqueeze(1) if self.batch_first else mask
return x
@staticmethod
def get_mask(x, p):
mask = x.new_empty(x.shape).bernoulli_(1 - p)
mask = mask / (1 - p)
return mask
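
# Usage sketch (illustrative, not from the original repo): SharedDropout samples one
# dropout mask per sequence and reuses it at every position, so whole feature
# channels are dropped consistently across time steps.
#   drop = SharedDropout(p=0.5)
#   drop.train()
#   x = torch.ones(2, 4, 8)        # (batch, seq, features)
#   y = drop(x)                    # each (batch, feature) entry is zeroed or scaled
#                                  # by 1/(1-p) identically for all 4 positions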
class IndependentDropout(nn.Module):
def __init__(self, p=0.5):
super(IndependentDropout, self).__init__()
self.p = p
def forward(self, *items):
if self.training and self.p > 0.:
# with -1 it should work for any input dim
#masks = [x.new_empty(x.shape[:2]).bernoulli_(1 - self.p)
# for x in items]
masks = [x.new_empty(x.shape[:-1]).bernoulli_(1 - self.p)
for x in items]
total = sum(masks)
scale = len(items) / total.max(torch.ones_like(total))
masks = [mask * scale for mask in masks]
items = [item * mask.unsqueeze(dim=-1)
for item, mask in zip(items, masks)]
return items
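
# Usage sketch (illustrative): IndependentDropout drops whole vectors (e.g. the word
# embedding or the tag embedding of a token) independently for each input stream and
# rescales the surviving ones so the expected total is preserved.
#   drop = IndependentDropout(p=0.33)
#   drop.train()
#   word_emb, tag_emb = torch.ones(2, 5, 100), torch.ones(2, 5, 20)
#   word_emb, tag_emb = drop(word_emb, tag_emb)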
## Stolen from: https://github.com/FilippoC/pydestruct/
## in turn inspired by https://github.com/nikitakit/self-attentive-parser
#class LayerNorm(nn.Module):
# def __init__(self, dim, mean=0., std=1., fixed=False, eps=1e-6, ball=False):
# super(LayerNorm, self).__init__()
# self.eps = eps
# self.ball = ball
# if fixed:
# self.target_mean = mean
# self.target_std = std
# else:
# self.target_mean = nn.Parameter(torch.empty(dim).fill_(mean))
# self.target_std = nn.Parameter(torch.empty(dim).fill_(std))
# def forward(self, x):
# mean = x.mean(-1, keepdim=True)
# std = torch.sqrt(torch.mean((x - mean).pow(2), dim=-1, keepdim=True) + self.eps)
# if self.ball:
# std = std.clamp(1.)
# return self.target_std * (x - mean) / std + self.target_mean
class LayerNorm(nn.Module):
# Layer norm from https://github.com/nikitakit/self-attentive-parser
def __init__(self, d_hid, eps=1e-3, affine=True):
super(LayerNorm, self).__init__()
self.eps = eps
self.affine = affine
if self.affine:
self.a_2 = nn.Parameter(torch.ones(d_hid), requires_grad=True)
self.b_2 = nn.Parameter(torch.zeros(d_hid), requires_grad=True)
def forward(self, z):
if z.size(-1) == 1:
return z
mu = torch.mean(z, keepdim=True, dim=-1)
sigma = torch.std(z, keepdim=True, dim=-1)
ln_out = (z - mu.expand_as(z)) / (sigma.expand_as(z) + self.eps)
if self.affine:
ln_out = ln_out * self.a_2.expand_as(ln_out) + self.b_2.expand_as(ln_out)
return ln_out
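
# Note (added): this is close to nn.LayerNorm(d_hid), except that eps is added to the
# standard deviation rather than the variance, torch.std is unbiased, and inputs with
# a single feature are returned unchanged. Illustrative usage:
#   ln = LayerNorm(128)
#   y = ln(torch.randn(2, 5, 128))   # normalized over the last dimension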
class MultiHeadAttention(nn.Module):
def __init__(self, input_dim, n_heads, query_dim, values_dim, output_dim, att_proj_bias=False, att_dropout=0.):
super().__init__()
self.n_heads = n_heads
self.query = nn.Linear(input_dim, n_heads * query_dim, bias=att_proj_bias)
self.key = nn.Linear(input_dim, n_heads * query_dim, bias=att_proj_bias)
self.value = nn.Linear(input_dim, n_heads * values_dim, bias=att_proj_bias)
        # For attention dropout we use "standard" dropout.
        # Note: we should also take the padding mask into account when applying attention dropout.
self.att_dropout = nn.Dropout(att_dropout)
self.temper = query_dim ** 0.5
self.output_proj = nn.Linear(n_heads * values_dim, output_dim)
torch.nn.init.xavier_uniform_(self.query.weight)
torch.nn.init.xavier_uniform_(self.key.weight)
torch.nn.init.xavier_uniform_(self.value.weight)
torch.nn.init.xavier_uniform_(self.output_proj.weight)
def forward(self, input_features, mask=None, mask_value=-float('inf')):
batch_size = input_features.shape[0]
sentence_size = input_features.shape[1]
# (batch size, sentence size, proj size * n head)
query = self.query(input_features)
key = self.key(input_features)
# (batch size, sentence size, values size * n_head)
values = self.value(input_features)
# (batch size, n heads, sentence size, proj size)
query = query.reshape(batch_size, query.shape[1], self.n_heads, -1).transpose(1, 2)
key = key.reshape(batch_size, key.shape[1], self.n_heads, -1).transpose(1, 2)
values = values.reshape(batch_size, values.shape[1], self.n_heads, -1).transpose(1, 2)
# (batch size, n heads, sentence size, sentence size)
att_scores = query @ key.transpose(-1, -2)
att_scores /= self.temper
if mask is not None:
att_scores.data.masked_fill_(mask.reshape(att_scores.shape[0], 1, 1, att_scores.shape[-1]), mask_value)
# compute attention
att = F.softmax(att_scores, dim=-1)
att = self.att_dropout(att)
# aggregate values:
# the softmax dim is the last dim and the sentence dim is last-1 dim,
# so this is ok
# (batch size, n head, sentence size, values size)
values = att @ values
# now we must concatenate all heads
# (batch size, sentence size, values size * n heads)
values = values.transpose(1, 2).reshape(batch_size, sentence_size, -1)
values = self.output_proj(values)
return values
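
# Usage sketch (illustrative shapes):
#   mha = MultiHeadAttention(input_dim=256, n_heads=8, query_dim=32,
#                            values_dim=32, output_dim=256)
#   x = torch.randn(2, 7, 256)                  # (batch, sentence, input_dim)
#   pad = torch.zeros(2, 7, dtype=torch.bool)   # True marks padded positions
#   out = mha(x, mask=pad)                      # (batch, sentence, output_dim)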
# I can't see the layer norm in the paper (sec. 3.3)
class PositionwiseFeedForward(nn.Module):
def __init__(self, input_dim, hidden_dim, dropout=0.1, shared_dropout=True):
super(PositionwiseFeedForward, self).__init__()
self.w_1 = nn.Linear(input_dim, hidden_dim)
self.w_2 = nn.Linear(hidden_dim, input_dim)
self.relu = nn.ReLU()
if shared_dropout:
# we can use it because input is (batch, n word, n features)
self.relu_dropout = SharedDropout(dropout)
else:
self.relu_dropout = nn.Dropout(dropout)
torch.nn.init.xavier_uniform_(self.w_1.weight)
torch.nn.init.xavier_uniform_(self.w_2.weight)
def forward(self, x):
inter = self.relu_dropout(self.relu(self.w_1(x)))
output = self.w_2(inter)
return output
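
# Usage sketch (illustrative): a position-wise two-layer MLP with dropout after the
# ReLU; the output dimension equals the input dimension so it composes with the
# residual connection in Layer below.
#   ff = PositionwiseFeedForward(input_dim=256, hidden_dim=1024, dropout=0.1)
#   y = ff(torch.randn(2, 7, 256))   # (batch, sentence, input_dim)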
class Layer(nn.Module):
def __init__(self,
input_dim,
n_heads,
query_dim,
values_dim,
ff_hidden_dim,
att_dropout=0.,
ff_dropout=0.,
residual_att_dropout=0.,
residual_ff_dropout=0.,
att_proj_bias=False,
shared_dropout=True,
pre_ln=False,
                 ball_norm=False  # currently unused; the "ball" option belongs to the commented-out LayerNorm above
):
super(Layer, self).__init__()
self.self_attn = MultiHeadAttention(
input_dim=input_dim,
n_heads=n_heads,
query_dim=query_dim,
values_dim=values_dim,
output_dim=input_dim, # must be the same size because of the residual connection
att_dropout=att_dropout,
att_proj_bias=att_proj_bias
)
self.feed_forward = PositionwiseFeedForward(
input_dim,
hidden_dim=ff_hidden_dim,
dropout=ff_dropout,
shared_dropout=shared_dropout
)
self.layer_norm1 = LayerNorm(input_dim)
self.layer_norm2 = LayerNorm(input_dim)
self.pre_ln = pre_ln
if shared_dropout:
# ok because input is (batch, n word, features)
self.dropout1 = SharedDropout(residual_att_dropout)
self.dropout2 = SharedDropout(residual_ff_dropout)
else:
self.dropout1 = nn.Dropout(residual_att_dropout)
self.dropout2 = nn.Dropout(residual_ff_dropout)
def forward(self, x, mask=None, mask_value=-float('inf')):
if self.pre_ln:
x1 = self.layer_norm1(x)
x1 = self.self_attn(x1, mask=mask, mask_value=mask_value)
x1 = self.dropout1(x1) + x
x2 = self.layer_norm2(x1)
x2 = self.feed_forward(x2)
x2 = self.dropout2(x2) + x1
else:
x1 = self.self_attn(x, mask=mask, mask_value=mask_value)
x1 = self.dropout1(x1) + x
x1 = self.layer_norm1(x1)
x2 = self.feed_forward(x1)
x2 = self.dropout2(x2) + x1
x2 = self.layer_norm2(x2)
return x2
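
# Usage sketch (illustrative): with pre_ln=True, layer norm is applied before the
# attention and feed-forward sublayers (pre-LN ordering); with pre_ln=False it is
# applied after each residual connection (post-LN, as in the original Transformer).
#   layer = Layer(input_dim=256, n_heads=8, query_dim=32, values_dim=32,
#                 ff_hidden_dim=1024, pre_ln=True)
#   y = layer(torch.randn(2, 7, 256))   # (batch, sentence, input_dim)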
class Transformer(nn.Module):
def __init__(self,
num_layers,
input_dim,
n_heads,
ff_dim_hidden,
query_dim,
values_dim,
att_dropout=0.,
ff_dropout=0.,
residual_att_dropout=0.,
residual_ff_dropout=0.,
norm_input=True,
att_proj_bias=False,
shared_dropout=True,
pre_ln=False,
                 ball_norm=False  # currently unused; the "ball" option belongs to the commented-out LayerNorm above
):
super().__init__()
self.layer_norm = LayerNorm(input_dim) if norm_input else None
self.layers = nn.ModuleList([
Layer(
input_dim=input_dim,
n_heads=n_heads,
query_dim=query_dim,
values_dim=values_dim,
ff_hidden_dim=ff_dim_hidden,
att_dropout=att_dropout,
ff_dropout=ff_dropout,
residual_att_dropout=residual_att_dropout,
residual_ff_dropout=residual_ff_dropout,
att_proj_bias=att_proj_bias,
shared_dropout=shared_dropout,
pre_ln=pre_ln
)
for _ in range(num_layers)
])
    # mask: must be False at word positions and True at masked (padding) positions
def forward(self, emb, lengths=None, mask=None, mask_value=-float('inf')):
if lengths is not None and mask is None:
# build mask from lengths
lengths = lengths.to(emb.device)
mask = torch.arange(emb.shape[1], device=emb.device).reshape(1, -1) >= lengths.reshape(-1, 1)
if self.layer_norm is not None:
emb = self.layer_norm(emb)
for i in range(len(self.layers)):
emb = self.layers[i](emb, mask=mask, mask_value=mask_value)
return emb
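
# Usage sketch (illustrative): the padding mask can be derived from sentence lengths,
# with True at padded positions.
#   enc = Transformer(num_layers=2, input_dim=256, n_heads=8, ff_dim_hidden=1024,
#                     query_dim=32, values_dim=32)
#   x = torch.randn(2, 7, 256)
#   out = enc(x, lengths=torch.tensor([7, 4]))   # (batch, sentence, input_dim)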
class TransformerNetwork(nn.Module):
def __init__(self, args, d_input):
super(TransformerNetwork, self).__init__()
if args.trans_position_embs:
d = args.trans_dmodel - d_input if args.trans_position_concat else d_input
self.position_table = nn.Parameter(torch.FloatTensor(args.trans_position_max, d))
if args.trans_position_concat:
d_input += d
self.position_concat = args.trans_position_concat
self.position_dropout = IndependentDropout(args.trans_position_dropout)
else:
self.position_table = None
if d_input != args.trans_dmodel:
raise RuntimeError("Input size mismatch: %i - %i" % (d_input, args.trans_dmodel))
self.encoder = Transformer(
num_layers=args.trans_n_layers,
input_dim=d_input,
n_heads=args.trans_n_heads,
ff_dim_hidden=args.trans_ff_hidden_dim,
query_dim=args.trans_query_dim,
values_dim=args.trans_value_dim,
att_dropout=args.trans_att_dropout,
ff_dropout=args.trans_ff_dropout,
residual_att_dropout=args.trans_residual_att_dropout,
residual_ff_dropout=args.trans_residual_ff_dropout,
norm_input=args.trans_norm_input,
att_proj_bias=args.trans_att_proj_bias,
shared_dropout=args.trans_shared_dropout,
pre_ln=args.trans_pre_ln
)
self._reset_parameters()
def _reset_parameters(self):
if self.position_table is not None:
torch.nn.init.uniform_(self.position_table, -0.01, 0.01)
def forward(self, features, lengths):
if self.position_table is not None:
# unsqueeze for batch dim
positions = self.position_dropout(self.position_table[:features.shape[1], :].unsqueeze(0))[0]
if self.position_concat:
features = torch.cat([features, positions.expand((features.shape[0], -1, -1))], dim=2)
else:
features = features + positions
output = self.encoder(features, lengths=lengths)
output = output.split([1 for _ in range(len(lengths))])
output = [o.squeeze(0)[:idx] for o, idx in zip(output, lengths)]
return [o[1:-1] for o in output] , output
@staticmethod
def add_cmd_options(cmd):
cmd.add_argument('--trans-dmodel', metavar="", type=int, default=1024, help=" ")
cmd.add_argument('--trans-n-heads', metavar="", type=int, default=8, help=" ")
cmd.add_argument('--trans-n-layers', metavar="", type=int, default=8, help=" ")
cmd.add_argument('--trans-ff-hidden-dim', metavar="", type=int, default=2048, help=" ")
cmd.add_argument('--trans-query-dim', metavar="", type=int, default=64, help=" ")
cmd.add_argument('--trans-value-dim', metavar="", type=int, default=64, help=" ")
cmd.add_argument('--trans-att-dropout', metavar="", type=float, default=0.2, help=" ")
cmd.add_argument('--trans-ff-dropout', metavar="", type=float, default=0.1, help=" ")
cmd.add_argument('--trans-residual-att-dropout', metavar="", type=float, default=0.2, help=" ")
cmd.add_argument('--trans-residual-ff-dropout', metavar="", type=float, default=0.1, help=" ")
cmd.add_argument('--trans-att-proj-bias', action="store_true", default=False, help=" ")
cmd.add_argument('--trans-shared-dropout', action="store_true", default=False, help=" ")
        cmd.add_argument('--trans-no-norm-input', dest="trans_norm_input", action="store_false", default=True, help=" ")
        cmd.add_argument('--trans-no-pre-ln', dest="trans_pre_ln", action="store_false", default=True, help=" ")
        cmd.add_argument('--trans-no-position-embs', dest="trans_position_embs", action="store_false", default=True, help=" ")
cmd.add_argument('--trans-position-concat', action="store_true", default=False, help=" ")
cmd.add_argument('--trans-position-max', metavar="", type=int, default=300, help=" ")
cmd.add_argument('--trans-position-dropout', metavar="", type=float, default=0.0, help=" ")
class SentenceEncoderLSTM(nn.Module):
def __init__(self, args, d_input):
super(SentenceEncoderLSTM, self).__init__()
self.args = args
self.word_transducer_l1 = nn.LSTM(input_size=d_input,
hidden_size=args.lstm_dim //2,
num_layers=1,
batch_first=True,
bidirectional=True)
assert(args.lstm_layers >= 2)
self.word_transducer_l2 = nn.LSTM(input_size=args.lstm_dim,
hidden_size=args.lstm_dim //2,
num_layers=args.lstm_layers -1,
batch_first=True,
bidirectional=True)
self.residual = args.lstm_residual_connection
self.ln = [int(i) for i in args.lstm_layer_norm]
self.layer_norm_l0 = nn.LayerNorm(d_input)
self.layer_norm_l1 = nn.LayerNorm(args.lstm_dim)
self.layer_norm_l2 = nn.LayerNorm(args.lstm_dim)
def forward(self, all_embeddings, lengths):
        if self.ln[0]:
            # LayerNorm acts on the last dimension, so it can be applied directly to the
            # padded tensor; pack_padded_sequence below expects a tensor, not a list.
            all_embeddings = self.layer_norm_l0(all_embeddings)
packed_padded_char_based_embeddings = torch.nn.utils.rnn.pack_padded_sequence(
all_embeddings, lengths, batch_first=True)
output_l1, (h_n, c_n) = self.word_transducer_l1(packed_padded_char_based_embeddings)
output_l2, (h_n, c_n) = self.word_transducer_l2(output_l1)
unpacked_l1, _ = torch.nn.utils.rnn.pad_packed_sequence(output_l1, batch_first=True)
unpacked_l2, _ = torch.nn.utils.rnn.pad_packed_sequence(output_l2, batch_first=True)
if self.residual:
unpacked_l2 = unpacked_l2 + unpacked_l1
unpacked_l1 = [t.squeeze(0) for t in unpacked_l1.split([1 for l in lengths], dim=0)]
unpacked_l1 = [t[1:l-1,:] for t, l in zip(unpacked_l1, lengths)]
if self.ln[1]:
unpacked_l1 = [self.layer_norm_l1(l1) for l1 in unpacked_l1]
unpacked_l2 = [t.squeeze(0) for t in unpacked_l2.split([1 for l in lengths], dim=0)]
unpacked_l2 = [t[:l,:] for t, l in zip(unpacked_l2, lengths)]
if self.ln[2]:
unpacked_l2 = [self.layer_norm_l2(l2) for l2 in unpacked_l2]
return unpacked_l1, unpacked_l2
@staticmethod
def add_cmd_options(cmd):
cmd.add_argument("--lstm-dim", "-W", metavar="d", type=int, default=400, help="Dimension of sentence bi-LSTM")
cmd.add_argument("--lstm-layers", "-P", metavar="l", type=int, default=2, help="Depth of word transducer, min=2")
cmd.add_argument("--lstm-residual-connection", metavar=" ", type=bool, default=True, help="Add residual connections between LSTM layers")
cmd.add_argument("--lstm-layer-norm", type=str, choices=["000", "001", "010", "100", "110", "101", "011", "111"], default="001",
help="Add layer normalizations at <input> <output l1> <output l2>")
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    TransformerNetwork.add_cmd_options(parser)
    args = parser.parse_args()
    # use a small model for this smoke test
    args.trans_n_heads = 2
    args.trans_n_layers = 2
    d_input = 1024
    model = TransformerNetwork(args, d_input)
    features = torch.rand(3, 10, d_input)
    # first return value: per-sentence encodings without boundary tokens;
    # second return value: per-sentence encodings including boundary tokens
    inner_outputs, full_outputs = model(features, lengths=torch.tensor([4, 8, 10]))
    print([o.shape for o in full_outputs])