https://github.com/webstorms/NeuralPred
Revision 0409f09ba6537b3c19d4103a144301929c972c9b authored by Luke Taylor on 07 October 2023, 15:12:52 UTC, committed by Luke Taylor on 07 October 2023, 15:12:52 UTC
0 parent
Tip revision: 0409f09ba6537b3c19d4103a144301929c972c9b authored by Luke Taylor on 07 October 2023, 15:12:52 UTC
init
init
Tip revision: 0409f09
neuron_to_tensor.py
import os
import argparse
import torch
from src import util
# Build em response tensors
def build_unit_activity_tensor(dataset):
unit_activity_list = []
for i, (_, y) in enumerate(dataset):
if i % 10000 == 0:
print(f"i={i}/{len(dataset)}")
unit_activity_list.append(y)
return torch.stack(unit_activity_list)
def main(args):
# Create directory to store all tensors
unit_activity_list = []
if args.dataset == "multi_pvc1":
for i in range(util.get_dataset_number_of_neurons(args.dataset)):
print(f"Building unit activity {i}...")
dataset = util.get_dataset(args.root, name=args.dataset, split=args.split, ntau=args.ntau, idx=i)
unit_activity_tensor = build_unit_activity_tensor(dataset)
print(unit_activity_tensor.shape)
unit_activity_list.append(unit_activity_tensor)
all_unit_activity_tensor = torch.stack(unit_activity_list)
else:
dataset = util.get_dataset(args.root, name=args.dataset, split=args.split, ntau=args.ntau, idx=None)
all_unit_activity_tensor = build_unit_activity_tensor(dataset)
if args.dataset == "cadena":
all_unit_activity_tensor = all_unit_activity_tensor.permute(1, 0, 2)
save_path = f"{args.root}/data/y/"
unit_tensor_name = os.path.join(save_path, f"{util.get_dataset_y_name(args.dataset, args.split, args.ntau)}.pt")
torch.save(all_unit_activity_tensor.float(), unit_tensor_name)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--root", type=str)
parser.add_argument("--dataset", type=str)
parser.add_argument("--split", type=str)
parser.add_argument("--ntau", type=int)
args = parser.parse_args()
main(args)

Computing file changes ...