# CopyGenerator.py
import torch.nn as nn
import torch.nn.functional as F
import torch
from torch.autograd import Variable
from UtilClass import shiftLeft, bottle, unbottle
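
# This module implements the output layers placed on top of the decoder state:
#   RegularGenerator -- plain log-softmax over the target code vocabulary.
#   ProdGenerator    -- distribution over grammar production rules, masked by the
#                       current non-terminal, with copying of identifiers from
#                       the source.
#   CopyGenerator    -- token-level generator with a copy mechanism over the
#                       source sequence; its loss is CopyGeneratorCriterion.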

def aeq(*args):
    """
    Assert all arguments have the same value
    """
    arguments = (arg for arg in args)
    first = next(arguments)
    assert all(arg == first for arg in arguments), \
        "Not all arguments have the same value: " + str(args)

class RegularGenerator(nn.Module):
  def __init__(self, rnn_size, vocabs, opt):
    super(RegularGenerator, self).__init__()
    self.opt = opt
    self.linear = nn.Linear(rnn_size, len(vocabs['code']))
    self.lsm = nn.LogSoftmax(dim=1)
    self.vocabs = vocabs
    self.target_padding_idx = vocabs['code'].stoi['<blank>']

    weight = torch.ones(len(vocabs['code']))
    weight[vocabs['code'].stoi['<blank>']] = 0
    self.criterion = nn.NLLLoss(weight, size_average=False)

  def forward(self, hidden, attn, src_map, batch):
    return self.lsm(self.linear(hidden))
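  # Note: forward ignores attn, src_map and batch here; the signature presumably
  # matches the copy-aware generators below so callers can swap generators
  # without changing the calling code.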

  def computeLoss(self, scores, batch):
    target = Variable(shiftLeft(batch['code'].cuda(), self.vocabs['code'].stoi['<blank>']).view(-1), requires_grad=False)
    loss = self.criterion(scores, target)
    scores_data = scores.data.clone()
    target_data = target.data.clone()  # detached copies used only for the accuracy statistics below

    pred = scores_data.max(1)[1]
    non_padding = target_data.ne(self.target_padding_idx)
    num_correct = pred.eq(target_data).masked_select(non_padding).sum()
    return loss, non_padding.sum(), num_correct
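  # Every computeLoss in this file returns the same triple: (summed NLL loss,
  # number of non-padding target tokens, number of correct predictions on those
  # tokens), presumably so the training loop can accumulate statistics uniformly
  # across generators.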

class ProdGenerator(nn.Module):
  def __init__(self, rnn_size, vocabs, opt):
    super(ProdGenerator, self).__init__()
    self.opt = opt
    self.mask = Variable(vocabs['mask'].float().cuda(), requires_grad=False)
    self.linear = nn.Linear(rnn_size, len(vocabs['next_rules']))  # scores only non-<unk> rules
    self.linear_copy = nn.Linear(rnn_size, 1)
    self.tgt_pad = vocabs['next_rules'].stoi['<blank>']
    self.tgt_unk = vocabs['next_rules'].stoi['<unk>']
    self.vocabs = vocabs

  def forward(self, hidden, attn, src_map, batch):
    out = self.linear(hidden)
    # batch['nt'] holds the current non-terminal id for each target position
    # (it also contains padding).
    batch_by_tlen_, slen = attn.size()
    batch_size, slen_, cvocab = src_map.size()

    # Mask out production rules that cannot expand the current non-terminal:
    # vocabs['mask'] is indexed by non-terminal id and (presumably) holds large
    # negative values for invalid rules, so the softmax drives them towards zero.
    non_terminals = batch['nt'].contiguous().cuda().view(-1)
    masked_out = torch.add(out, torch.index_select(self.mask, 0, Variable(non_terminals, requires_grad=False)))
    prob = F.softmax(masked_out, dim=1)

    # Probability of copying, p(z=1), for each token.
    copy = F.sigmoid(self.linear_copy(hidden))

    # Probability of not copying: p_{word}(w) * (1 - p(z=1)).
    # Copying is only allowed when the current non-terminal is IdentifierNT; for
    # every other non-terminal, masked_copy is 0 and prob is left untouched.
    masked_copy = Variable(non_terminals.cuda().view(-1, 1).eq(self.vocabs['nt'].stoi['IdentifierNT']).float()) * copy
    out_prob = torch.mul(prob, 1 - masked_copy.expand_as(prob))
    mul_attn = torch.mul(attn, masked_copy.expand_as(attn))
    copy_prob = torch.bmm(mul_attn.view(batch_size, -1, slen), Variable(src_map.cuda(), requires_grad=False))
    copy_prob = copy_prob.view(-1, cvocab) # bottle it again to get batch_by_len times cvocab
    return torch.cat([out_prob, copy_prob], 1) # batch_by_tlen x (out_vocab + cvocab)
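    # Layout of the returned scores: columns [0, len(next_rules)) are rule
    # probabilities, columns [len(next_rules), len(next_rules) + cvocab) are copy
    # probabilities over this batch's source copy vocabulary; computeLoss relies
    # on this layout via `offset = len(self.vocabs['next_rules'])`.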

  def computeLoss(self, scores, batch):

    batch_size = batch['seq2seq'].size(0)

    target = Variable(batch['next_rules'].contiguous().cuda().view(-1), requires_grad=False)
    if self.opt.decoder_type == "prod":
      align = Variable(batch['next_rules_in_src_nums'].contiguous().cuda().view(-1), requires_grad=False)
      align_unk = batch['seq2seq_vocab'][0].stoi['<unk>']
    elif self.opt.decoder_type in ["concode"]:
      align = Variable(batch['concode_next_rules_in_src_nums'].contiguous().cuda().view(-1), requires_grad=False)
      align_unk = batch['concode_vocab'][0].stoi['<unk>']

    offset = len(self.vocabs['next_rules'])

    # Copy probability of the aligned source token, zeroed where the alignment is <unk>.
    out = scores.gather(1, align.view(-1, 1) + offset).view(-1).mul(align.ne(align_unk).float())
    # Generation probability of the gold rule.
    tmp = scores.gather(1, target.view(-1, 1)).view(-1)

    # Add the generation probability when the target rule is in-vocabulary, and
    # also when both the target rule and its copy alignment are <unk>.
    out = out + 1e-20 + tmp.mul(target.ne(self.tgt_unk).float()) + \
          tmp.mul(align.eq(align_unk).float()).mul(target.eq(self.tgt_unk).float())

    # Drop padding positions and sum the negative log-likelihood.
    loss = -out.log().mul(target.ne(self.tgt_pad).float()).sum()
    scores_data = scores.data.clone()
    target_data = target.data.clone()  # detached copies used only for the accuracy statistics below

    scores_data = self.collapseCopyScores(unbottle(scores_data, batch_size), batch)
    scores_data = bottle(scores_data)

    # Correct target copy token instead of <unk>
    # tgt[i] = align[i] + len(tgt_vocab)
    # for i such that tgt[i] == 0 and align[i] != 0
    # when target is <unk> but can be copied, make sure we get the copy index right
    correct_mask = target_data.eq(self.tgt_unk) * align.data.ne(align_unk)
    correct_copy = (align.data + offset) * correct_mask.long()
    target_data = (target_data * (1 - correct_mask).long()) + correct_copy

    pred = scores_data.max(1)[1]
    non_padding = target_data.ne(self.tgt_pad)
    num_correct = pred.eq(target_data).masked_select(non_padding).sum()

    return loss, non_padding.sum(), num_correct #, stats

  def collapseCopyScores(self, scores, batch):
    """
    Given scores from an expanded dictionary
    corresponding to a batch, sums together copies,
    with a dictionary word when it is ambigious.
    """
    tgt_vocab = self.vocabs['next_rules']
    offset = len(tgt_vocab)
    for b in range(batch['seq2seq'].size(0)):
      if self.opt.decoder_type == "prod":
        src_vocab = batch['seq2seq_vocab'][b]
      elif self.opt.decoder_type in ["concode"]:
        src_vocab = batch['concode_vocab'][b]

      for i in range(1, len(src_vocab)):
        sw = "IdentifierNT-->" + src_vocab.itos[i]
        ti = tgt_vocab.stoi[sw] if sw in tgt_vocab.stoi else self.tgt_unk
        if ti != self.tgt_unk:
          scores[b, :, ti] += scores[b, :, offset + i]
          scores[b, :, offset + i].fill_(1e-20)
    return scores
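  # Example: if source word i is "foo" and the rule "IdentifierNT-->foo" exists in
  # the next_rules vocabulary at index ti, the copy score in column offset + i is
  # added into column ti and the copy column is then set to 1e-20, so predictions
  # and corrected targets point at one canonical index.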

class CopyGenerator(nn.Module):
    """
    Generator module that additionally considers copying
    words directly from the source.
    """
    def __init__(self, rnn_size, vocabs, opt):
        super(CopyGenerator, self).__init__()
        self.opt = opt
        self.tgt_dict_size = len(vocabs['code'])
        self.tgt_padding_idx = vocabs['code'].stoi['<blank>']
        self.tgt_unk_idx = vocabs['code'].stoi['<unk>']
        self.vocabs = vocabs
        self.linear = nn.Linear(rnn_size, self.tgt_dict_size)
        self.linear_copy = nn.Linear(rnn_size, 1)
        force_copy = False  # when False, in-vocabulary generation probabilities are also credited (see CopyGeneratorCriterion)
        self.criterion = CopyGeneratorCriterion(self.tgt_dict_size, force_copy, self.tgt_padding_idx, self.tgt_unk_idx)

    def forward(self, hidden, copy_attn, src_map, batch):
        """
        Computes p(w) = p(z=1) * p_{copy}(w|z=1)  +  p(z=0) * p_{softmax}(w|z=0)
        """
        # CHECKS
        batch_by_tlen, _ = hidden.size()
        batch_by_tlen_, slen = copy_attn.size()
        batch_size, slen_, cvocab = src_map.size()
        aeq(batch_by_tlen, batch_by_tlen_)
        aeq(slen, slen_)

        # Original probabilities.
        logits = self.linear(hidden)
        logits[:, self.tgt_padding_idx] = -float('inf')
        prob = F.softmax(logits, dim=1)

        # Probability of copying, p(z=1), for each token.
        copy = F.sigmoid(self.linear_copy(hidden))

        # Probability of not copying: p_{word}(w) * (1 - p(z=1))
        out_prob = torch.mul(prob,  1 - copy.expand_as(prob))
        mul_attn = torch.mul(copy_attn, copy.expand_as(copy_attn))
        copy_prob = torch.bmm(mul_attn.view(batch_size, -1, slen), Variable(src_map.cuda(), requires_grad=False))
        copy_prob = copy_prob.view(-1, cvocab) # bottle it again to get batch_by_len times cvocab
        return torch.cat([out_prob, copy_prob], 1) # batch_by_tlen x (out_vocab + cvocab)
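        # As in ProdGenerator.forward: the first self.tgt_dict_size columns are
        # generation probabilities and the remaining cvocab columns are copy
        # probabilities over the per-batch source copy vocabulary.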

    def computeLoss(self, scores, batch):
        """
        Args:
            batch: the current batch.
            target: the validate target to compare output with.
            align: the align info.
        """
        batch_size = batch['seq2seq'].size(0)

        self.target = Variable(shiftLeft(batch['code'].cuda(), self.tgt_padding_idx).view(-1), requires_grad=False)

        align = Variable(shiftLeft(batch['code_in_src_nums'].cuda(),  self.vocabs['seq2seq'].stoi['<blank>']).view(-1), requires_grad=False)
        # All individual vocabs have the same unk index
        align_unk = batch['seq2seq_vocab'][0].stoi['<unk>']
        loss = self.criterion(scores, self.target, align, align_unk)

        scores_data = scores.data.clone()
        target_data = self.target.data.clone()  # detached copies used only for the accuracy statistics below

        if self.opt.copy_attn:
          scores_data = self.collapseCopyScores(unbottle(scores_data, batch_size), batch)
          scores_data = bottle(scores_data)

          # Correct target copy token instead of <unk>
          # tgt[i] = align[i] + len(tgt_vocab)
          # for i such that tgt[i] == 0 and align[i] != 0
          # when target is <unk> but can be copied, make sure we get the copy index right
          correct_mask = target_data.eq(self.tgt_unk_idx) * align.data.ne(align_unk)
          correct_copy = (align.data + self.tgt_dict_size) * correct_mask.long()
          target_data = (target_data * (1 - correct_mask).long()) + correct_copy


        pred = scores_data.max(1)[1]
        non_padding = target_data.ne(self.tgt_padding_idx)
        num_correct = pred.eq(target_data).masked_select(non_padding).sum()

        return loss, non_padding.sum(), num_correct #, stats

    def collapseCopyScores(self, scores, batch):
      """
      Given scores from an expanded dictionary
      corresponding to a batch, sums together copies,
      with a dictionary word when it is ambigious.
      """
      tgt_vocab = self.vocabs['code']
      offset = len(tgt_vocab)
      for b in range(batch['seq2seq'].size(0)):
        src_vocab = batch['seq2seq_vocab'][b]
        for i in range(1, len(src_vocab)):
          sw = src_vocab.itos[i]
          ti = tgt_vocab.stoi[sw] if sw in tgt_vocab.stoi else self.tgt_unk_idx
          if ti != self.tgt_unk_idx:
            scores[b, :, ti] += scores[b, :, offset + i]
            scores[b, :, offset + i].fill_(1e-20)
      return scores


class CopyGeneratorCriterion(object):
    def __init__(self, vocab_size, force_copy, tgt_pad, tgt_unk, eps=1e-20):
        self.force_copy = force_copy
        self.eps = eps
        self.offset = vocab_size
        self.tgt_pad = tgt_pad
        self.tgt_unk = tgt_unk

    def __call__(self, scores, target, align, copy_unk):
        # Copy prob.
        out = scores.gather(1, align.view(-1, 1) + self.offset) \
                    .view(-1).mul(align.ne(copy_unk).float())
        tmp = scores.gather(1, target.view(-1, 1)).view(-1)

        # Regular prob (no unks and unks that can't be copied)
        if not self.force_copy:
            # first one = target is not unk
            out = out + self.eps + tmp.mul(target.ne(self.tgt_unk).float()) + \
                  tmp.mul(align.eq(copy_unk).float()).mul(target.eq(self.tgt_unk).float()) # copy and target are unks
        else:
            # Forced copy.
            out = out + self.eps + tmp.mul(align.eq(0).float())

        # Drop padding.
        loss = -out.log().mul(target.ne(self.tgt_pad).float()).sum()
        return loss
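

# Minimal CPU-only sketch (not part of the original training pipeline): it runs
# CopyGeneratorCriterion on toy tensors to illustrate the expected shapes. The
# index conventions below (pad=0, unk=1) and all sizes are assumptions made only
# for this illustration.
if __name__ == "__main__":
    vocab_size, copy_vocab_size, num_tokens = 5, 3, 4
    tgt_pad, tgt_unk, copy_unk = 0, 1, 1
    criterion = CopyGeneratorCriterion(vocab_size, False, tgt_pad, tgt_unk)
    # One row per target token over the extended (generation + copy) vocabulary,
    # mimicking the concatenated output of CopyGenerator.forward.
    scores = torch.ones(num_tokens, vocab_size + copy_vocab_size) * 0.125
    target = torch.LongTensor([2, tgt_unk, 3, tgt_pad])          # gold token ids
    align = torch.LongTensor([copy_unk, 2, copy_unk, copy_unk])  # copy positions; only the 2nd token is copiable
    loss = criterion(scores, target, align, copy_unk)
    print(loss)  # summed negative log-likelihood over the non-padding tokens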