https://github.com/webstorms/NeuralPred
Tip revision: 1484b1ae509bf58a2cc2f711e525fd1d225b9b79 authored by Luke Taylor on 07 October 2023, 15:17:28 UTC
typo fix
util.py
import torch
from src import datasets, models

def get_dataset(root, name, split, ntau, idx, normalise=True, scale=1):
    """Build the requested neural dataset split; idx selects the recording for the multi_pvc1 dataset."""
    if name == "single_pvc1":
        dataset = datasets.SinglePVC1Dataset(root, split, ntau, normalise, scale)
    elif name == "multi_pvc1":
        dataset = datasets.MultiPVC1Dataset(root, split, ntau, idx, normalise, scale)
    elif name == "cadena":
        dataset = datasets.CadenaDataset(root, split, ntau, normalise, scale)
    else:
        raise NotImplementedError
    return dataset
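
# A hedged usage sketch (not part of the original module): the data root and the
# dataset arguments below are illustrative placeholders. The returned object is one
# of the dataset classes from src.datasets and is expected to behave like a torch
# dataset (e.g. it can be wrapped in a DataLoader), though the exact interface is
# defined in src/datasets.
def _example_get_dataset(data_root="/path/to/data"):
    # Training split of the multi-unit pvc-1 dataset, recording index 0,
    # with a temporal window of ntau=9 frames.
    return get_dataset(data_root, name="multi_pvc1", split="train", ntau=9, idx=0)
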

def get_dataset_max_ccs(name, type="ccnorm"):
    """Per-neuron ceiling correlations: CC_norm values ("ccnorm") or the square root of SPE values ("spe")."""
    assert type in ["spe", "ccnorm"]
    if name == "single_pvc1":
        spe, ccnorm = [1.0], [1.0]  # Ceiling of 1; wrapped in lists so torch.Tensor builds a value, not an uninitialised size-1 tensor
    elif name == "multi_pvc1":
spe = [0.5744946002960205, 0.4064042866230011, 0.0763465166091919, 0.6352482438087463, 0.905774712562561, 0.8341729640960693, 0.17452464997768402, 0.06605248898267746, 0.8186521530151367, 0.0787293091416359, 0.4269949793815613, -0.015010775066912174, 0.5089885592460632, 0.014189318753778934, 0.3708527386188507, 0.37515363097190857, 0.2074914276599884, 0.651958703994751, 0.587609589099884, 0.7997542023658752, 0.19506847858428955, 0.1626891940832138, 0.3523372411727905]
ccnorm = [0.737077534198761, 0.6249262094497681, 0.27525997161865234, 0.772854745388031, 0.9113429188728333, 0.8774657845497131, 0.41416308283805847, 0.25616225600242615, 0.8698875308036804, 0.2794893682003021, 0.6399289965629578, 0.001, 0.6959428191184998, 0.11903421580791473, 0.5979894399642944, 0.6013222932815552, 0.45085883140563965, 0.7823395729064941, 0.7449815273284912, 0.860540509223938, 0.4374198913574219, 0.4001060128211975, 0.5833914875984192]
    elif name == "cadena":
spe = [0.8294596672058105, 0.7743522524833679, 0.8139088153839111, 0.8523905873298645, 0.8834021687507629, 0.8405280709266663, 0.7888814806938171, 0.8806656002998352, 0.7885940670967102, 0.7575450539588928, 0.8697508573532104, 0.7821177840232849, 0.8904778957366943, 0.7729043364524841, 0.8723242878913879, 0.765550971031189, 0.8104876279830933, 0.732178270816803, 0.7408490777015686, 0.8157734870910645, 0.8091776371002197, 0.847025454044342, 0.7669594883918762, 0.7741407155990601, 0.8356441855430603, 0.7774149179458618, 0.7893873453140259, 0.7886411547660828, 0.8594353795051575, 0.7597140073776245, 0.7696136832237244, 0.8762743473052979, 0.8709105849266052, 0.8742618560791016, 0.8372071981430054, 0.8356213569641113, 0.7400267720222473, 0.8121872544288635, 0.8723312616348267, 0.7857792377471924, 0.7779542803764343, 0.8322813510894775, 0.7660675644874573, 0.8069775700569153, 0.8050919771194458, 0.7506529092788696, 0.8469997644424438, 0.8157544136047363, 0.8388168811798096, 0.8575052618980408, 0.8645200729370117, 0.9168898463249207, 0.6092856526374817, 0.6570807099342346, 0.6016409397125244, 0.5625165700912476, 0.6352408528327942, 0.7155857682228088, 0.6412456631660461, 0.7296238541603088, 0.6927119493484497, 0.7414397597312927, 0.6199454069137573, 0.5255076885223389, 0.440677672624588, 0.6527426838874817, 0.45821037888526917, 0.5625847578048706, 0.5461782813072205, 0.5486024022102356, 0.6481673717498779, 0.7279085516929626, 0.5370447039604187, 0.6637105941772461, 0.6784877181053162, 0.8297263979911804, 0.5310347080230713, 0.5434422492980957, 0.7023946642875671, 0.7186911106109619, 0.5459261536598206, 0.8798953890800476, 0.6446630358695984, 0.6411856412887573, 0.8978777527809143, 0.5424546599388123, 0.9083402156829834, 0.5660972595214844, 0.6976428627967834, 0.6164090037345886, 0.5813226103782654, 0.860040545463562, 0.7388808131217957, 0.8128939270973206, 0.8266680240631104, 0.6176729202270508, 0.732837438583374, 0.5497377514839172, 0.6105810403823853, 0.4265868663787842, 0.7522940635681152, 0.5391196608543396, 0.7310096025466919, 0.6399711966514587, 0.5078704953193665, 0.7312605977058411, 0.7765926122665405, 0.735821545124054, 0.5001757740974426, 0.7441239953041077, 0.42179569602012634, 0.5614448189735413, 0.48709818720817566, 0.5280357599258423, 0.45917248725891113, 0.5348649621009827, 0.5594353079795837, 0.5760141015052795, 0.43652766942977905, 0.5337889194488525, 0.7033160924911499, 0.5131003856658936, 0.7617539167404175, 0.5105043649673462, 0.4581662118434906, 0.6627364158630371, 0.5795906782150269, 0.5743682384490967, 0.6536250710487366, 0.4129616916179657, 0.4557209610939026, 0.5338262915611267, 0.6286768913269043, 0.6113972663879395, 0.7735698819160461, 0.7080456018447876, 0.7765780091285706, 0.6732655167579651, 0.5791230201721191, 0.7513883709907532, 0.40734049677848816, 0.6913051605224609, 0.8259644508361816, 0.3911234736442566, 0.5535159707069397, 0.6776208877563477, 0.5250111222267151, 0.5638331174850464, 0.742428719997406, 0.6347155570983887, 0.6643922924995422, 0.6903847455978394, 0.5802809000015259, 0.7631508708000183, 0.5685293078422546, 0.4325319230556488, 0.5558401942253113, 0.5836763381958008, 0.5684800148010254, 0.4896102547645569, 0.6317916512489319, 0.5330033302307129, 0.6986749172210693, 0.8249865770339966, 0.802408754825592, 0.5049863457679749]
ccnorm = [0.8288546204566956, 0.8054563999176025, 0.8223732113838196, 0.8382458686828613, 0.8506442904472351, 0.8334119319915771, 0.8117435574531555, 0.849563717842102, 0.8116199970245361, 0.7980732917785645, 0.8452283143997192, 0.8088275194168091, 0.8534260392189026, 0.8048250675201416, 0.846254289150238, 0.8016051650047302, 0.8209347128868103, 0.7866976857185364, 0.7906181216239929, 0.8231552243232727, 0.8203827738761902, 0.8360660076141357, 0.8022236227989197, 0.805364191532135, 0.8314067125320435, 0.8067889213562012, 0.8119608759880066, 0.811640202999115, 0.8410924077033997, 0.7990328669548035, 0.803386926651001, 0.8478245139122009, 0.8456910252571106, 0.8470252156257629, 0.8320493698120117, 0.8313973546028137, 0.7902478575706482, 0.821649968624115, 0.8462570905685425, 0.8104084134101868, 0.8070232272148132, 0.8300207257270813, 0.8018321394920349, 0.8194541335105896, 0.818656861782074, 0.7950106263160706, 0.8360555768013, 0.8231472373008728, 0.8327103853225708, 0.8403142690658569, 0.8431357741355896, 0.8636611104011536, 0.7271494269371033, 0.7512465119361877, 0.723173201084137, 0.7022560238838196, 0.7403942942619324, 0.7790996432304382, 0.7434040307998657, 0.7855361700057983, 0.7684124112129211, 0.7908840179443359, 0.7326360940933228, 0.6815314292907715, 0.6300367712974548, 0.749111533164978, 0.6411833167076111, 0.7022932767868042, 0.6932246088981628, 0.6945760846138, 0.7468488216400146, 0.7847546339035034, 0.6880954504013062, 0.754490077495575, 0.7616373896598816, 0.8289649486541748, 0.6846880316734314, 0.6916942596435547, 0.7729672193527222, 0.7805314064025879, 0.6930838227272034, 0.8492591977119446, 0.7451080679893494, 0.7433740496635437, 0.8563171625137329, 0.6911405920982361, 0.860373318195343, 0.7042112350463867, 0.7707375884056091, 0.7308232188224792, 0.7124316096305847, 0.8413360714912415, 0.7897311449050903, 0.821946918964386, 0.8276978731155396, 0.7314720153808594, 0.7869969010353088, 0.6952076554298401, 0.7278196811676025, 0.6208679676055908, 0.7957417964935303, 0.6892658472061157, 0.7861666679382324, 0.7427669763565063, 0.6713063716888428, 0.7862806916236877, 0.8064315319061279, 0.7883490920066833, 0.6667708158493042, 0.7920901775360107, 0.6177058815956116, 0.7016690373420715, 0.6589545607566833, 0.6829780340194702, 0.6417868137359619, 0.6868626475334167, 0.7005665898323059, 0.7095824480056763, 0.6273562908172607, 0.686252772808075, 0.7733982801437378, 0.67436283826828, 0.7999334931373596, 0.6728483438491821, 0.6411555409431458, 0.7540149688720703, 0.7115039825439453, 0.7086954116821289, 0.7495466470718384, 0.611814558506012, 0.6396177411079407, 0.6862739324569702, 0.7370811700820923, 0.7282415628433228, 0.8051153421401978, 0.7756044268608093, 0.8064252734184265, 0.7591243386268616, 0.7112532258033752, 0.7953383922576904, 0.6080236434936523, 0.7677468657493591, 0.8274058699607849, 0.5968965888023376, 0.6973031759262085, 0.761221170425415, 0.6812466979026794, 0.702975869178772, 0.7913287878036499, 0.7401300668716431, 0.7548223733901978, 0.7673109173774719, 0.7118738889694214, 0.8005492091178894, 0.7055343985557556, 0.6247598528862, 0.6985873579978943, 0.7136890888214111, 0.7055076956748962, 0.6604666709899902, 0.7386562824249268, 0.6858070492744446, 0.7712227702140808, 0.8269997239112854, 0.8175197839736938, 0.6696117520332336]
    else:
        raise NotImplementedError
    if type == "ccnorm":
        return torch.Tensor(ccnorm)
    else:
        return torch.Tensor(spe) ** 0.5
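
# Hedged sketch of how the ceiling values above are typically used: dividing a
# model's raw per-neuron test correlations by the corresponding ceiling gives a
# noise-corrected score. "raw_ccs" is a placeholder tensor, not part of this repo.
def _example_normalised_score(raw_ccs, name="cadena"):
    max_ccs = get_dataset_max_ccs(name, type="ccnorm")
    return raw_ccs / max_ccs  # ceiling-normalised correlations, one value per neuron
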

def get_valid_neurons(name, type="ccnorm", thresh=0.15):
    """Boolean mask of neurons whose ceiling correlation exceeds thresh."""
    max_ccs = get_dataset_max_ccs(name, type)
    return max_ccs > thresh
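
# Hedged example of combining the helpers above: average a per-neuron score over
# the neurons whose ceiling passes the reliability threshold. "scores" is a
# placeholder tensor of per-neuron values.
def _example_mean_over_valid_neurons(scores, name="multi_pvc1", thresh=0.15):
    valid = get_valid_neurons(name, type="ccnorm", thresh=thresh)
    return scores[valid].mean()
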

def get_dataset_number_of_neurons(name):
    """Total number of recorded neurons in each dataset."""
    if name == "single_pvc1":
        return 127
    elif name == "multi_pvc1":
        return 23
    elif name == "cadena":
        return 166
    else:
        raise NotImplementedError

def get_model(root, model_name, ntau, nlat, nspan, layer):
    """Build the model named by model_name; layer selects a readout layer where applicable."""
    n_warmup = ntau - nspan  # Lowers compute time for spatial models on spatio-temporal stimuli
    if model_name == "stacktp":
        return models.TPModel(root, int(layer))
    elif model_name == "ext_stacktp":
        return models.TPExtModel(root, int(layer))
    elif model_name == "random_stacktp":
        return models.RandomTPModel(int(layer))
    elif model_name.split("_")[0] == "stack":
        # Names of the form "stack_<loss>_<lam>_<detach>" encode the model's hyper-parameters.
        parts = model_name.split("_")
        loss, lam, detach = parts[1], parts[2], parts[3]
        return models.StackModel(root, loss, lam, int(layer), detach=detach)
    elif model_name == "prednet":
        return models.PrednetModel(root, layer)
    elif model_name == "bwt":
        return models.BWTModel(root, n_warmup)
    elif model_name == "vgg":
        return models.VGGModel(layer, n_warmup)
    else:
        raise NotImplementedError
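
# Hedged sketch of building a model; "weights_root" is a placeholder directory for
# pretrained weights, the hyper-parameters are illustrative, and layer only matters
# for models that expose multiple layers (e.g. "vgg" or the "stack"/"stacktp" variants).
def _example_get_model(weights_root="/path/to/models"):
    return get_model(weights_root, model_name="vgg", ntau=9, nlat=0, nspan=1, layer=2)
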

def get_dataset_y_name(name, split, ntau):
    """Identifier string for a dataset's recorded responses (y)."""
    return f"{name}_{split}_{ntau}"


def get_dataset_x_name(model, dataset, split, ntau, nlat, nspan, scale, n_pca, layer=None):
    """Identifier string for a model's features (x) on a dataset split, optionally tagged with a layer."""
    if layer is None:
        return f"{model}_{dataset}_{split}_{ntau}_{nlat}_{nspan}_{scale}_{n_pca}"
    else:
        return f"{model}_{layer}_{dataset}_{split}_{ntau}_{nlat}_{nspan}_{scale}_{n_pca}"
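
# Hedged illustration of the naming helpers above, which build identifier strings
# (e.g. for cached response/feature files); the argument values are placeholders.
def _example_names():
    y_name = get_dataset_y_name("cadena", split="train", ntau=9)  # "cadena_train_9"
    x_name = get_dataset_x_name("vgg", "cadena", "train", 9, 0, 1, 1, 500, layer=2)  # "vgg_2_cadena_train_9_0_1_1_500"
    return y_name, x_name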