https://gitlab.com/mcoavoux/mtgpy-release-findings-2021.git
Tip revision: c9972219cd75049269d26632d2bb79619d661298 authored by mcoavoux on 20 May 2021, 13:04:44 UTC
up readme
up readme
Tip revision: c997221
features.py
import torch
def set2tensor(iset, gap):
    """Encode a constituent by the shifted indexes of its boundaries.

    Despite its name, this function returns a plain list of ints, not a
    tensor.

    Token indexes are offset by ``shift`` because the first indexes of the
    embedding matrix are reserved for additional parameters:

    - 0: default left
    - 1: default right
    - 2: default left gap
    - 3: default right gap
    - deprecated: default buffer -> use EOS instead
    - 4: start of sentence symbol

    Args:
        iset: Non-empty iterable of integer token indexes making up the
            constituent. Any iterable is accepted, including one-shot
            iterators.
        gap: When true, the gap boundaries are included in the result.

    Returns:
        ``[left, right, gap_left, gap_right]`` when ``gap`` is true, otherwise
        ``[left, right]``. ``left`` and ``right`` are the smallest and largest
        index of ``iset`` plus ``shift``. ``gap_left`` and ``gap_right`` are
        the smallest and largest index lying between them that is missing
        from ``iset``, plus ``shift``; when nothing is missing they are the
        reserved defaults 2 and 3.

    Raises:
        ValueError: If ``iset`` is empty.
    """
    shift = 5
    default_gap_left, default_gap_right = 2, 3

    # Materialise once: gives O(1) membership tests whatever the input
    # container is, and lets single-pass iterables be read safely.
    members = set(iset)
    if not members:
        raise ValueError("set2tensor() arg is an empty set of indexes")

    left = min(members)
    right = max(members)

    gap_left, gap_right = default_gap_left, default_gap_right
    # Both ends belong to the constituent, so only the interior can hold gaps.
    # range() is increasing, hence the first/last misses are the gap bounds.
    missing = [i for i in range(left + 1, right) if i not in members]
    if missing:
        gap_left = missing[0] + shift
        gap_right = missing[-1] + shift

    if gap:
        return [left + shift, right + shift, gap_left, gap_right]
    return [left + shift, right + shift]
def configuration_features_all(device, stack, queue, buffer, sent_len):
    """Build the "all" feature vector (13 indexes) for a configuration.

    Layout of the result: ``s1 (2) + s0 (4) + q1 (2) + q0 (4) + b0 (1)``,
    where ``s0``/``s1`` are the last and second-to-last items of ``stack``
    and ``q0``/``q1`` those of ``queue``. The last items are encoded with
    their gap boundaries, the second-to-last ones without. Missing items are
    replaced by the default indexes ``[0, 1, 2, 3]`` / ``[0, 1]``.

    Args:
        device: Device on which the tensor is created.
        stack: Sized, indexable collection whose last two items (if present)
            are passed to ``set2tensor``.
        queue: Same as ``stack``.
        buffer: Index appended as-is as the last feature.
        sent_len: Unused here; presumably kept so that every feature function
            shares the same signature.

    Returns:
        A 1-D ``torch.long`` tensor holding the 13 indexes.
    """
    default_with_gap = [0, 1, 2, 3]  # left, right, left gap, right gap
    default_no_gap = [0, 1]  # left, right

    features = []
    for items in (stack, queue):
        depth = len(items)
        # Second-to-last item first, then the last one (s1 + s0, q1 + q0).
        features += (
            set2tensor(items[-2], gap=False) if depth > 1 else default_no_gap
        )
        features += (
            set2tensor(items[-1], gap=True) if depth > 0 else default_with_gap
        )

    # The buffer index is always taken from the caller; there is no default.
    features.append(buffer)
    return torch.tensor(features, dtype=torch.long, device=device)
def configuration_features_tacl_base(device, stack, queue, buffer, sent_len):
    """Build the "tacl_base" feature vector (8 indexes) for a configuration.

    Layout of the result: ``s1 + s0 + q1 + q0``, two indexes each, where
    ``s0``/``s1`` are the last and second-to-last items of ``stack`` and
    ``q0``/``q1`` those of ``queue``, all encoded without gap boundaries.
    A missing item is replaced by the default indexes ``[0, 1]``.

    ``buffer`` and ``sent_len`` are accepted but not used.

    Returns:
        A 1-D ``torch.long`` tensor created on ``device``.
    """
    placeholder = [0, 1]
    indexes = []
    for items in (stack, queue):
        size = len(items)
        second = set2tensor(items[-2], gap=False) if size > 1 else placeholder
        first = set2tensor(items[-1], gap=False) if size > 0 else placeholder
        indexes.extend(second)
        indexes.extend(first)
    return torch.tensor(indexes, dtype=torch.long, device=device)
def configuration_features_tacl_buf(device, stack, queue, buffer, sent_len):
    """Build the "tacl" feature vector (9 indexes) for a configuration.

    Layout of the result: ``s1 + s0 + q1 + q0 + [buffer]``. The first four
    groups hold two indexes each: the last (``s0``, ``q0``) and
    second-to-last (``s1``, ``q1``) items of ``stack`` and ``queue``, encoded
    without gap boundaries, or ``[0, 1]`` when the item does not exist.

    ``sent_len`` is accepted but not used.

    Returns:
        A 1-D ``torch.long`` tensor created on ``device``.
    """

    def boundaries(items, depth):
        # Encode the depth-th item from the end, or fall back to the defaults.
        if len(items) < depth:
            return [0, 1]
        return set2tensor(items[-depth], gap=False)

    indexes = (
        boundaries(stack, 2)
        + boundaries(stack, 1)
        + boundaries(queue, 2)
        + boundaries(queue, 1)
        + [buffer]
    )
    return torch.tensor(indexes, dtype=torch.long, device=device)
def configuration_features_global(device, stack, queue, buffer, sent_len):
    """Build the "global" feature vector (11 indexes) for a configuration.

    Layout of the result:
    ``s1 + s0 + q1 + q0 + [buffer, shift - 1, sent_len + shift]``.
    The first four groups hold two indexes each: the last (``s0``, ``q0``)
    and second-to-last (``s1``, ``q1``) items of ``stack`` and ``queue``,
    encoded without gap boundaries, or ``[0, 1]`` when the item does not
    exist.

    Returns:
        A 1-D ``torch.long`` tensor created on ``device``.
    """
    shift = 5

    def boundaries(items, depth):
        # Encode the depth-th item from the end, or fall back to the defaults.
        if len(items) >= depth:
            return set2tensor(items[-depth], gap=False)
        return [0, 1]

    # shift - 1 == 4, listed as the start-of-sentence index in set2tensor.
    sentence_start = shift - 1
    # NOTE(review): presumably the shifted end-of-sentence position -- confirm.
    sentence_end = sent_len + shift

    indexes = []
    for items in (stack, queue):
        indexes += boundaries(items, 2)
        indexes += boundaries(items, 1)
    indexes += [buffer, sentence_start, sentence_end]
    return torch.tensor(indexes, dtype=torch.long, device=device)
# Registry of the available feature extractors, keyed by name.
# Each entry is (function, number of indexes in the tensor it returns).
feature_functions = {
    "all": (configuration_features_all, 13),  # 2 + 4 + 2 + 4 + 1
    "tacl": (configuration_features_tacl_buf, 9),  # 4 * 2 + 1
    "tacl_base": (configuration_features_tacl_base, 8),  # 4 * 2
    "global": (configuration_features_global, 11),  # 4 * 2 + 3
}