Skip to main content
  • Home
  • Development
  • Documentation
  • Donate
  • Operational login
  • Browse the archive

swh logo
SoftwareHeritage
Software
Heritage
Archive
Features
  • Search

  • Downloads

  • Save code now

  • Add forge now

  • Help

  • e290e55
  • /
  • nade.py
Raw File Download

To reference or cite the objects present in the Software Heritage archive, permalinks based on SoftWare Hash IDentifiers (SWHIDs) must be used.
Select below a type of object currently browsed in order to display its associated SWHID and permalink.

  • content
  • directory
content badge
swh:1:cnt:67e5bef4150cdf771c18aa690350473d4a10d583
directory badge
swh:1:dir:e290e55a07697f47f9a18a6f6afb695f1b1683b4

This interface generates software citations, provided that the root directory of the browsed objects contains a citation.cff or codemeta.json file.
Select below a type of object currently browsed in order to generate citations for them.

  • content
  • directory
Generate software citation in BibTeX format (requires biblatex-software package)
Generating citation ...
Generate software citation in BibTeX format (requires biblatex-software package)
Generating citation ...
nade.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from torch.distributions import Distribution
from torch.distributions import Categorical

"""
A Discrete NADE, as a torch Distribution object
"""
class DiscreteNADEDistribution(Distribution):
    """Discrete NADE exposed as a ``torch.distributions.Distribution``.

    The joint over ``data_size`` discrete variables is factored
    autoregressively.  Variable ``i`` gets a categorical distribution whose
    logits are ``v[i](sigmoid(a_i))``, where ``a_0`` is the bias ``c``
    (broadcast over the batch) and
    ``a_{i+1} = a_i + w[i](x_i / data_domain_sizes[i])``.

    Args:
        data_size: number of variables in one data point.
        data_domain_sizes: indexable; entry ``i`` is the divisor used to
            scale the value of variable ``i`` before it is fed to ``w[i]``.
        hidden_size: width of the hidden layer; stored on the instance but
            not otherwise read by this class.
        w: indexable of ``data_size - 1`` callables, each mapping a
            ``(batch, 1)`` tensor to an update that is added to the hidden
            pre-activation.
        v: indexable of ``data_size`` callables, each mapping the hidden
            activation to the logits of the corresponding variable.
        c: 1-D tensor used as the initial hidden pre-activation.
    """

    # This distribution has no constrained tensor arguments of its own.  An
    # explicit empty mapping lets the base class initialise (and repr) itself
    # without warning about, or failing on, missing argument constraints.
    arg_constraints = {}

    def __init__(self, data_size, data_domain_sizes, hidden_size, w, v, c):
        # Initialise the base class so batch_shape / event_shape are defined:
        # there are no batch dimensions, and one event is a length-data_size
        # vector of category indices.
        super(DiscreteNADEDistribution, self).__init__(
            batch_shape=torch.Size(),
            event_shape=torch.Size([data_size]),
        )
        self.data_size = data_size
        self.data_domain_sizes = data_domain_sizes
        self.hidden_size = hidden_size
        self.w = w
        self.v = v
        self.c = c
        self.temperature = 1

    def set_temperature(self, temp):
        """Set the temperature that divides the logits when sampling.

        ``log_prob`` is not affected by this value.
        """
        self.temperature = temp

    def _initial_activation(self, batch_size):
        """Return ``c`` broadcast to shape ``(batch_size, len(c))``."""
        return self.c.unsqueeze(0).expand(batch_size, self.c.size(0))

    def _advance(self, a, i, val):
        """Fold the value of variable ``i`` into the hidden pre-activation.

        ``val`` has shape ``(batch, 1)``.  It is divided by the domain size,
        so indices ``0 .. size - 1`` map into ``[0, 1)``, and is cast to the
        dtype of ``a`` so that it matches what ``w[i]`` is added to.
        """
        normalized_val = val.to(a.dtype) / self.data_domain_sizes[i]
        return self.w[i](normalized_val) + a

    def log_prob(self, x):
        """Return the log-probability of every row of ``x``.

        Args:
            x: integer tensor indexed as ``x[:, i]``; column ``i`` holds the
                category index of variable ``i``.

        Returns:
            Tensor with one entry per row of ``x``: the sum of the
            per-variable conditional log-probabilities.  The sampling
            temperature is not applied here.
        """
        a = self._initial_activation(x.size(0))
        last = self.data_size - 1

        log_probs = []
        for i in range(self.data_size):
            h = torch.sigmoid(a)
            # Parameterise by logits so the log-probabilities come from a
            # log-softmax.  Going through probabilities makes Categorical
            # clamp them at machine epsilon, which caps how negative the
            # log-probability of an unlikely value can get.
            dist = Categorical(logits=self.v[i](h))
            val = x[:, i]
            log_probs.append(dist.log_prob(val))
            # There is no w for the last variable: nothing is conditioned on it.
            if i < last:
                a = self._advance(a, i, val.unsqueeze(1))

        # Stack to (batch, data_size) and sum over the variables.
        return torch.stack(log_probs, 1).sum(1)

    def sample(self):
        """Draw a single data point; returns a tensor of shape ``(1, data_size)``."""
        return self.sample_n(1)

    def sample_n(self, n):
        """Draw ``n`` data points by ancestral sampling.

        Each variable is sampled from its conditional with the logits divided
        by ``self.temperature``, and the sampled value conditions the
        variables that follow.

        Returns:
            Integer tensor of shape ``(n, data_size)``.
        """
        last = self.data_size - 1
        outputs = []

        # The samples are integer indices and carry no gradient, so there is
        # no need to record the computation for autograd.
        with torch.no_grad():
            a = self._initial_activation(n)
            for i in range(self.data_size):
                h = torch.sigmoid(a)
                logits = self.v[i](h) / self.temperature
                val = Categorical(logits=logits).sample().unsqueeze(1)
                outputs.append(val)
                if i < last:
                    a = self._advance(a, i, val)

        # Concatenate the (n, 1) columns; the first dimension is the batch.
        return torch.cat(outputs, 1)


"""
A Discrete NADE, as a nn Module
"""
class DiscreteNADEModule(nn.Module):
    """Trainable discrete NADE.

    Owns the parameters (initial bias ``c``, per-step input layers ``w`` and
    per-step output layers ``v``) and delegates scoring and sampling to a
    ``DiscreteNADEDistribution`` built from them on every call.

    Args:
        data_size: number of variables in one data point.
        data_domain_sizes: indexable; entry ``i`` is the number of output
            units of the ``i``-th output layer.
        hidden_size: width of the hidden layer.
    """

    def __init__(self, data_size, data_domain_sizes, hidden_size):
        super(DiscreteNADEModule, self).__init__()

        self.data_size = data_size
        self.data_domain_sizes = data_domain_sizes
        self.hidden_size = hidden_size
        # Only the sampling methods hand this value to the distribution.
        self.temperature = 1

        # Initial hidden bias; it is filled in by reset_parameters().
        self.c = Parameter(torch.Tensor(hidden_size))
        self.reset_parameters()

        # First layer: one bias-free scalar -> hidden map for every variable
        # except the last.  ModuleList registers the parameters automatically.
        self.w = nn.ModuleList(
            [nn.Linear(1, hidden_size, bias=False) for _ in range(data_size - 1)]
        )

        # Second layer: one hidden -> domain-size map per variable.
        self.v = nn.ModuleList(
            [
                nn.Linear(hidden_size, data_domain_sizes[step], bias=True)
                for step in range(data_size)
            ]
        )

    def reset_parameters(self):
        """Re-draw ``c`` uniformly from ``[-1/sqrt(hidden_size), 1/sqrt(hidden_size)]``."""
        bound = 1. / math.sqrt(self.hidden_size)
        self.c.data.uniform_(-bound, bound)

    def set_temperature(self, temp):
        """Store the temperature that later sampling calls will use."""
        self.temperature = temp

    def forward(self, x):
        """Alias for ``log_prob`` so that the module can be called directly."""
        return self.log_prob(x)

    def _as_distribution(self):
        # A distribution view over the current parameters.  Its temperature
        # is left at the distribution's own default.
        return DiscreteNADEDistribution(
            self.data_size,
            self.data_domain_sizes,
            self.hidden_size,
            self.w,
            self.v,
            self.c,
        )

    def log_prob(self, x):
        """Return the distribution's ``log_prob`` of ``x``; no temperature is set."""
        return self._as_distribution().log_prob(x)

    def sample(self):
        """Return the distribution's ``sample()`` at the stored temperature."""
        distribution = self._as_distribution()
        distribution.set_temperature(self.temperature)
        return distribution.sample()

    def sample_n(self, n):
        """Return the distribution's ``sample_n(n)`` at the stored temperature."""
        distribution = self._as_distribution()
        distribution.set_temperature(self.temperature)
        return distribution.sample_n(n)


# ##### TEST
# n = DiscreteNADEModule(10, [3, 3, 3, 3, 3, 3, 3, 3, 3, 3], 5)
# # inp = torch.autograd.Variable(torch.LongTensor(4, 10).fill_(1))
# # lp = n(inp)
# # print(lp)
# out = n.sample_n(4)
# print(out)
# ##### END TEST

back to top

Software Heritage — Copyright (C) 2015–2026, The Software Heritage developers. License: GNU AGPLv3+.
The source code of Software Heritage itself is available on our development forge.
The source code files archived by Software Heritage are available under their own copyright and licenses.
Terms of use: Archive access, API— Content policy— Contact— JavaScript license information— Web API