# nade.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from torch.distributions import Categorical, Distribution
"""
A Discrete NADE, as a torch Distribution object
"""
class DiscreteNADEDistribution(Distribution):
    def __init__(self, data_size, data_domain_sizes, hidden_size, w, v, c):
        super(DiscreteNADEDistribution, self).__init__()
        self.data_size = data_size
        self.data_domain_sizes = data_domain_sizes
        self.hidden_size = hidden_size
        self.w = w  # per-step first-layer weights (shared with the module)
        self.v = v  # per-step output layers (shared with the module)
        self.c = c  # initial hidden-layer bias
        self.temperature = 1
def set_temperature(self, temp):
self.temperature = temp
    def log_prob(self, x):
        log_probs = []
        batch_size = x.size(0)
        # Broadcast the initial bias to the batch size of x
        a = self.c.unsqueeze(0).expand(batch_size, -1)
        for i in range(self.data_size):
            h = torch.sigmoid(a)
            probs = F.softmax(self.v[i](h), dim=1)
            dist = Categorical(probs)
            val = x[:, i]
            lp = dist.log_prob(val)
            log_probs.append(lp.unsqueeze(1))
            # Scale the integer value into [0, 1) before feeding it back in
            normalized_val = (val.float() / self.data_domain_sizes[i]).unsqueeze(1)
            if i < self.data_size - 1:
                a = self.w[i](normalized_val) + a
        return torch.sum(torch.cat(log_probs, 1), 1)  # First dimension is batch dim
    def sample(self):
        # A single sample, returned as a batch of size 1
        return self.sample_n(1)
    def sample_n(self, n):
        outputs = []
        # Broadcast the initial bias to a batch of n samples
        a = self.c.unsqueeze(0).expand(n, -1)
        for i in range(self.data_size):
            h = torch.sigmoid(a)
            logits = self.v[i](h)
            # T < 1 sharpens each conditional, T > 1 flattens it
            probs = F.softmax(logits / self.temperature, dim=1)
            dist = Categorical(probs)
            val = dist.sample().unsqueeze(1)
            outputs.append(val)
            # Scale the sampled integer into [0, 1) before feeding it back in
            normalized_val = val.float() / self.data_domain_sizes[i]
            if i < self.data_size - 1:
                a = self.w[i](normalized_val) + a
        return torch.cat(outputs, 1)  # First dimension is batch dim
"""
A Discrete NADE, as a nn Module
"""
class DiscreteNADEModule(nn.Module):
    def __init__(self, data_size, data_domain_sizes, hidden_size):
        super(DiscreteNADEModule, self).__init__()
        self.data_size = data_size
        self.data_domain_sizes = data_domain_sizes
        self.hidden_size = hidden_size
        # The initial hidden-layer bias
        self.c = Parameter(torch.empty(hidden_size))
        self.reset_parameters()
        # One first-layer linear module per step, mapping the previous
        # (normalized, scalar) value into the hidden state; bias-free since
        # the shared bias c already plays that role
        # (Need ModuleList to make automatic parameter registration work)
        self.w = nn.ModuleList()
        for i in range(data_size - 1):
            self.w.append(nn.Linear(1, hidden_size, bias=False))
        # One second-layer linear module per step, producing the logits over
        # that variable's domain
        self.v = nn.ModuleList()
        for i in range(data_size):
            self.v.append(nn.Linear(hidden_size, data_domain_sizes[i], bias=True))
        self.temperature = 1
    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.hidden_size)
        nn.init.uniform_(self.c, -stdv, stdv)
def set_temperature(self, temp):
self.temperature = temp
    def forward(self, x):
        # Calling the module computes per-example log-likelihoods, so it can
        # be trained directly by minimizing -self(x).mean()
        return self.log_prob(x)
    def _make_distribution(self):
        nade = DiscreteNADEDistribution(self.data_size, self.data_domain_sizes,
                                        self.hidden_size, self.w, self.v, self.c)
        nade.set_temperature(self.temperature)
        return nade
    def log_prob(self, x):
        return self._make_distribution().log_prob(x)
    def sample(self):
        return self._make_distribution().sample()
    def sample_n(self, n):
        return self._make_distribution().sample_n(n)
# ##### TEST
# n = DiscreteNADEModule(10, [3, 3, 3, 3, 3, 3, 3, 3, 3, 3], 5)
# inp = torch.ones(4, 10, dtype=torch.long)
# lp = n(inp)
# print(lp)
# out = n.sample_n(4)
# print(out)
# ##### END TEST
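# ##### EXAMPLE: temperature sampling
# A sketch of the temperature knob, with illustrative values (0.5, 2.0).
# Dividing logits by T < 1 sharpens each conditional, so samples concentrate
# on high-probability values; T > 1 flattens it toward uniform over each
# domain. Note that log_prob always scores against the T = 1 distribution.
# n = DiscreteNADEModule(10, [3] * 10, 5)
# n.set_temperature(0.5)
# print(n.sample_n(4))  # greedier samples
# n.set_temperature(2.0)
# print(n.sample_n(4))  # more diverse samples
# ##### END EXAMPLE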