https://github.com/webstorms/NeuralPred
Tip revision: 9eccfe575524736f25f4f3759fe04648f22554de authored by Luke Taylor on 09 April 2024, 14:00:54 UTC
Added brainbox version.
Added brainbox version.
Tip revision: 9eccfe5
readout.py
import torch.nn as nn
import torch.nn.functional as F
from brainbox import models
class Readout(models.BBModel):
def __init__(self, n_in, n_out):
super().__init__()
self._n_in = n_in
self._n_out = n_out
self._linear = nn.Linear(n_in, n_out)
self.init_weight(self._linear.weight, "glorot_uniform")
@property
def hyperparams(self):
return {**super().hyperparams, "n_in": self._n_in, "n_out": self._n_out}
def get_params(self):
return [self._linear.weight]
def forward(self, x):
# x: b x n
x = self._linear(x)
x = F.softplus(x, 1, 20)
return x
