https://github.com/webstorms/NeuralPred
Tip revision: 1484b1ae509bf58a2cc2f711e525fd1d225b9b79 authored by Luke Taylor on 07 October 2023, 15:17:28 UTC
typo fix
typo fix
Tip revision: 1484b1a
readout.py
import torch.nn as nn
import torch.nn.functional as F
from brainbox import models
class Readout(models.BBModel):
    """Linear readout followed by a softplus non-linearity.

    Computes ``softplus(W x + b)``, mapping ``n_in`` input features to
    ``n_out`` positive outputs.
    """

    def __init__(self, n_in, n_out):
        """Build the readout layer.

        Args:
            n_in: Number of input features (``in_features`` of the linear map).
            n_out: Number of outputs (``out_features`` of the linear map).
        """
        super().__init__()
        # Sizes are kept so they can be reported through ``hyperparams``.
        self._n_in, self._n_out = n_in, n_out
        self._linear = nn.Linear(n_in, n_out)
        # Re-initialise the weight matrix with the BBModel helper's
        # "glorot_uniform" scheme. Only the weight is passed, so the bias is
        # presumably left as nn.Linear initialised it.
        self.init_weight(self._linear.weight, "glorot_uniform")

    @property
    def hyperparams(self):
        """Parent hyperparameters extended with ``n_in`` and ``n_out``."""
        merged = dict(super().hyperparams)
        merged.update(n_in=self._n_in, n_out=self._n_out)
        return merged

    def get_params(self):
        """Return a one-element list holding the linear weight.

        The bias is deliberately not included; presumably callers use this
        list for weight-only regularisation. Confirm against callers.
        """
        weight = self._linear.weight
        return [weight]

    def forward(self, x):
        """Apply the readout.

        Args:
            x: Input tensor whose last dimension has size ``n_in``. The
                original note describes it as ``b x n``, i.e. presumably
                (batch, n_in). Confirm with callers.

        Returns:
            Tensor with last dimension ``n_out``, with softplus applied
            element-wise.
        """
        pre_activation = self._linear(x)
        # beta=1 and threshold=20 are spelled out explicitly. Above the
        # threshold, PyTorch evaluates softplus as the identity for
        # numerical stability.
        return F.softplus(pre_activation, beta=1, threshold=20)