https://github.com/webstorms/NeuralPred
Revision 0409f09ba6537b3c19d4103a144301929c972c9b authored by Luke Taylor on 07 October 2023, 15:12:52 UTC, committed by Luke Taylor on 07 October 2023, 15:12:52 UTC
0 parent
Tip revision: 0409f09ba6537b3c19d4103a144301929c972c9b authored by Luke Taylor on 07 October 2023, 15:12:52 UTC
init
init
Tip revision: 0409f09
models.py
import os
import sys
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.models as models
from stack import train as stack_train
class StackModel:
    """Feature extractor backed by a pretrained "stack" model.

    Loads the checkpoint named ``{loss}_{lam}_{detach}_False`` from
    ``{root}/dependencies/stack`` and exposes the activations of one of the
    model's four stacks as a flat (time x neurons) tensor.
    """

    def __init__(self, root, loss, lam, output_stack, detach=False):
        assert output_stack in [1, 2, 3, 4]
        self._output_stack = output_stack
        checkpoint_name = f"{loss}_{lam}_{detach}_False"
        print(f"{checkpoint_name} loading....")
        self._model = stack_train.Trainer.load_model(f"{root}/dependencies/stack", checkpoint_name)
        self._model.eval()

    def __call__(self, x):
        # x: b x n x t x h x w (original author's note; the unsqueeze below
        # suggests x actually arrives without the leading batch dim — verify)
        activity = self._model(x.cuda().unsqueeze(0), layer=self._output_stack)
        # Drop the batch dim, move time to the front, flatten space+channels
        flattened = activity[0].permute(1, 0, 2, 3).flatten(1, 3)
        return flattened.cpu()
class PrednetModel:
    """Feature extractor wrapping the pretrained Keras PredNet (KITTI weights).

    The network's output is redirected to one of the error layers
    ('E0'..'E3'), and ``__call__`` returns that layer's activity for the
    last ``INPUT_LEN`` frames of a clip as a (time x neurons) tensor.
    """

    INPUT_LEN = 10  # Seems like this is the cap enforced by the prednet model

    def __init__(self, root, output_stack):
        """
        Args:
            root: project root containing ``dependencies/prednet``.
            output_stack: which error layer to expose ('E0', 'E1', 'E2' or 'E3').
        """
        assert output_stack in ("E0", "E1", "E2", "E3")
        # Load model specific dependencies (prednet ships its own module)
        sys.path.append(os.path.join(root, "dependencies", "prednet"))
        from keras.models import model_from_json
        from prednet import PredNet
        self.output_stack = output_stack  # starts at 'E0'
        # Locate the pretrained KITTI weights and architecture JSON
        weights_dir = os.path.join(root, 'dependencies', 'prednet', 'model_data_keras2')
        weights_file = os.path.join(weights_dir, 'tensorflow_weights/prednet_kitti_weights.hdf5')
        json_file = os.path.join(weights_dir, 'prednet_kitti_model.json')
        # Load trained model; context manager guarantees the handle is closed
        with open(json_file, 'r') as f:
            json_string = f.read()
        prednet_config = model_from_json(json_string, custom_objects={'PredNet': PredNet})
        prednet_config.load_weights(weights_file)
        # Re-instantiate the PredNet layer with its output set to the requested layer
        layer_config = prednet_config.layers[1].get_config()
        layer_config['output_mode'] = output_stack
        self.prednet = PredNet(weights=prednet_config.layers[1].get_weights(), **layer_config)
        # The Keras Model is built lazily on first call, once the input size is known
        self.model = None

    def __call__(self, x):
        import keras
        from keras.layers import Input
        # x: chan x time x height x width
        if self.model is None:
            channels = 3
            h, w = x.shape[2], x.shape[3]
            if h == 73 and w == 73:  # Little fix to ensure we can base in data scaled at 0.66x
                h, w = 80, 80
            input_shape = [PrednetModel.INPUT_LEN, channels, h, w]
            print(input_shape)
            inputs = Input(shape=tuple(input_shape))
            predictions = self.prednet(inputs)
            self.model = keras.models.Model(inputs=inputs, outputs=predictions)
        h, w = x.shape[2], x.shape[3]
        if h == 73 and w == 73:  # Pad 73x73 inputs up to 80x80 (see above)
            x = F.pad(x, (3, 4, 3, 4))
        assert x.shape[2] % 2**3 == 0  # A requirement of the prednet model
        x = x.repeat(3, 1, 1, 1)  # Repeat the grayscale channel three times
        x = x.permute(1, 0, 2, 3)  # c x t x h x w -> t x c x h x w
        x = x.unsqueeze(0)  # Add batch dimension
        x = x[:, -PrednetModel.INPUT_LEN:, ]  # Only look at INPUT_LEN last frames
        # Second positional arg to predict is batch_size — presumably capping it
        # at INPUT_LEN; verify against Keras Model.predict signature
        x = self.model.predict(x.numpy(), PrednetModel.INPUT_LEN)  # Get hidden activity
        x = torch.from_numpy(x[0]).flatten(start_dim=1, end_dim=-1)  # time x neurons
        return x
class ImgModelBase:
    """Base class for frame-by-frame (static-image) feature extractors.

    Subclasses implement ``model_output``, mapping a single frame to a flat
    feature tensor; ``__call__`` applies it independently to every frame of
    a clip and stacks the results into a (time x neurons) tensor.
    """

    def __init__(self, n_warmup=None):
        # Number of leading frames to drop before extraction (None keeps all)
        self.n_warmup = n_warmup

    def __call__(self, x):
        # x: chan x time x height x width — only single-channel clips supported
        assert x.shape[0] == 1
        x = x[0]  # -> time x height x width
        if self.n_warmup is not None:
            x = x[self.n_warmup:]
        # Run the static model independently on every remaining frame
        activity = torch.stack([self.model_output(frame) for frame in x])  # time x neurons
        return activity.detach().cpu()

    def model_output(self, x):
        """Map one frame (height x width) to a flat feature tensor."""
        raise NotImplementedError
class BWTModel(ImgModelBase):
    """Frame-wise feature extractor using an Octave BWT V1 implementation.

    Pads each frame up to a power-of-3 square (a BWT requirement, presumably)
    and runs it through ``bwt_v1_octave`` via oct2py.
    """

    def __init__(self, root, n_warmup=None):
        super().__init__(n_warmup)
        # Load model specific dependencies (Octave bridge)
        from oct2py import octave
        octave.addpath(os.path.join(root, "dependencies", "bwt"))
        self.bwt = octave.bwt_v1_octave

    def model_output(self, x):
        h, w = x.shape
        # Smallest power of 3 >= h, computed in exact integer arithmetic.
        # (The previous float form 3**ceil(log(h)/log(3)) could overshoot for
        # exact powers of 3 due to log round-off.)
        target_dim = 1
        while target_dim < h:
            target_dim *= 3
        # NOTE(review): target size is derived from h only; assumes frames are
        # square (or w <= target_dim) — confirm with callers
        xt = F.pad(x, (0, target_dim - w, 0, target_dim - h))
        return torch.from_numpy(self.bwt(xt.numpy())).flatten()
class VGGModel(ImgModelBase):
    """Frame-wise feature extractor using a pretrained VGG16.

    Each grayscale frame is tiled to 3 channels and passed through the first
    ``layer + 1`` modules of ``vgg16().features``; the activations are
    returned flattened.
    """

    # Named conv stages mapped to their index into vgg16().features
    _LAYER_TO_INDEX = {'2.1': 6, '2.2': 8, '3.1': 11, '3.2': 13, '3.3': 15}

    def __init__(self, layer, n_warmup=None):
        """
        Args:
            layer: an int index into ``vgg16().features``, or one of the
                stage names in ``_LAYER_TO_INDEX`` (e.g. '3.1').
            n_warmup: leading frames to drop (see ImgModelBase).
        """
        super().__init__(n_warmup)
        self.vgg = models.vgg16(pretrained=True).cuda()
        self.layer = layer

    def model_output(self, x):
        x = x.cuda()
        # Resolve a stage name to its feature index once, then cache it.
        # (The previous if/elif chain silently left unknown strings
        # unconverted, producing a confusing TypeError in range() below.)
        if isinstance(self.layer, str):
            if self.layer not in VGGModel._LAYER_TO_INDEX:
                raise ValueError(f"Unknown VGG layer name: {self.layer}")
            self.layer = VGGModel._LAYER_TO_INDEX[self.layer]
        x = x.unsqueeze(0).unsqueeze(0)  # Add back batch and channel dim
        x = x.repeat(1, 3, 1, 1)  # Repeat the grayscale image three times for the rgb channels
        with torch.no_grad():
            for i in range(self.layer + 1):
                x = self.vgg.features[i](x)
        return x.flatten()

Computing file changes ...