https://github.com/GPflow/GPflow
Tip revision: 03af631bdbefbe67f85c81e74df10b860d05588f authored by Artem Artemev on 03 November 2019, 22:18:47 UTC
Resolve merge conflicts from develop
Resolve merge conflicts from develop
Tip revision: 03af631
train_mnist.py
# Copyright 2019 Mark van der Wilk, Vincent Dutordoir
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
train_mnist.py
A simple script to train a convolutional GP on MNIST. Training stats are
sent to TensorBoard in the ./mnist/ directory.
Usage examples (for using float64 and float32 respectively):
`python train_mnist.py`
`python train_mnist.py with float_type=float32 jitter_level=1e-4`
The latter should reach around 1.23% error after 120k iterations.
"""
import datetime
import tensorflow as tf
import os
from pathlib import Path
import numpy as np
from sacred import Experiment
from tensorflow.keras import Sequential
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard
from tensorflow.keras.initializers import constant, truncated_normal
from tensorflow.keras.layers import (BatchNormalization, Conv2D, Dense,
Dropout, Flatten, MaxPool2D)
from tensorflow.keras.regularizers import l2
from sklearn.feature_extraction.image import extract_patches_2d
import gpflow
import gpflow.training.monitor as mon
# Name under which the sacred experiment is registered.
NAME = "mnist"
# Sacred experiment object; it provides the @ex.config, @ex.capture and
# @ex.automain decorators used throughout this script.
ex = Experiment(NAME)
def calc_binary_error(model, Xs, Ys, batchsize=100):
    """Return the binary classification error of `model` in percent.

    The inputs are evaluated in roughly `batchsize`-sized chunks. For each
    chunk the first output of ``model.predict_y`` is thresholded at 0.5 and
    compared element-wise with the labels.

    :param model: object exposing ``predict_y(xs)`` returning a 2-tuple whose
        first element holds the predicted probabilities.
    :param Xs: input array, one row per example.
    :param Ys: label array aligned with `Xs`.
    :param batchsize: approximate number of examples per prediction call.
    :return: percentage of misclassified entries.
    """
    Ns = len(Xs)
    # np.array_split needs at least one section; without the clamp an input
    # smaller than `batchsize` would raise ValueError.
    splits = max(1, Ns // batchsize)
    hits = []
    for xs, ys in zip(np.array_split(Xs, splits), np.array_split(Ys, splits)):
        p, _ = model.predict_y(xs)
        hits.append((p > 0.5).astype('float') == ys)
    error = 1.0 - np.concatenate(hits, 0)
    return np.sum(error) * 100.0 / len(error)
def calc_multiclass_error(model, Xs, Ys, batchsize=100):
    """Return the multiclass classification error of `model` in percent.

    The inputs are evaluated in roughly `batchsize`-sized chunks. For each
    chunk the arg-max over axis 1 of the first output of ``model.predict_y``
    is compared with the first column of the labels.

    :param model: object exposing ``predict_y(xs)`` returning a 2-tuple whose
        first element holds per-class scores, one row per example.
    :param Xs: input array, one row per example.
    :param Ys: label array aligned with `Xs`; the class index is read from
        column 0.
    :param batchsize: approximate number of examples per prediction call.
    :return: percentage of misclassified examples.
    """
    Ns = len(Xs)
    # np.array_split needs at least one section; without the clamp an input
    # smaller than `batchsize` would raise ValueError.
    splits = max(1, Ns // batchsize)
    hits = []
    for xs, ys in zip(np.array_split(Xs, splits), np.array_split(Ys, splits)):
        p, _ = model.predict_y(xs)
        hits.append(p.argmax(1) == ys[:, 0])
    error = 1.0 - np.concatenate(hits, 0)
    return np.sum(error) * 100.0 / len(error)
def get_error_cb(model, Xs, Ys, error_func, full=False, Ns=500):
    """Create a zero-argument-friendly callback that evaluates `error_func`.

    The returned callable ignores whatever positional and keyword arguments
    it receives and calls ``error_func(model, xs, ys, batchsize=50)``. With
    ``full=True`` the complete `Xs` / `Ys` are used; otherwise only the
    first `Ns` entries of each.
    """
    def error_cb(*args, **kwargs):
        inputs, targets = (Xs, Ys) if full else (Xs[:Ns], Ys[:Ns])
        return error_func(model, inputs, targets, batchsize=50)

    return error_cb
def save_gpflow_model(filename, model) -> None:
    """Write `model` to `filename` using a freshly created ``gpflow.Saver``."""
    saver = gpflow.Saver()
    saver.save(filename, model)
def get_dataset(dataset: str):
    """Load MNIST through ``tf.keras`` and return ``(X, Y), (Xs, Ys)``.

    Only ``"mnist"`` and ``"mnist-small-subset-for-test"`` are accepted; any
    other name trips the assertion. Images are flattened to 784 columns and
    divided by 255, labels are cast to int and reshaped to a single column.
    For the small-subset variant each of the four arrays is truncated to its
    first 300 rows.
    """
    assert dataset in ("mnist", "mnist-small-subset-for-test")

    (X, Y), (Xs, Ys) = tf.keras.datasets.mnist.load_data()
    X, Xs = [images.reshape(-1, 784) / 255.0 for images in (X, Xs)]
    Y, Ys = [labels.astype(int).reshape(-1, 1) for labels in (Y, Ys)]

    if dataset == "mnist-small-subset-for-test":
        X, Xs, Y, Ys = [arr[:300, :] for arr in (X, Xs, Y, Ys)]

    return (X, Y), (Xs, Ys)
# Default sacred configuration. Every local variable assigned in this
# function becomes a config entry that @ex.capture functions can receive by
# parameter name.
@ex.config
def config():
    model_type = "convgp"  # convgp | cnn
    dataset = "mnist"
    lr_cfg = dict(decay="custom", decay_steps=30000, lr=1e-3)
    date = datetime.datetime.now().strftime("%b%d_%H:%M")
    iterations = 120000
    patch_shape = [5, 5]
    batch_size = 128
    # root directory under which results are saved
    basepath = "./"
    num_inducing_points = 1000
    base_kern = "RBF"
    init_patches = "patches-unique"  # "patches-unique" | "random"
    restore = False
    # periods handed to mon.PeriodicIterationCondition by the monitor tasks
    hz = dict(slow=1000, short=50)
    float_type = "float64"
    jitter_level = 1e-6
@ex.capture
def get_data(dataset, model_type):
    """Load the train/test split and shape it for the selected model.

    `dataset` and `model_type` are injected from the sacred configuration.

    For the convolutional GP the flattened images returned by
    ``get_dataset`` are passed through unchanged. For ``model_type == "cnn"``
    the images are reshaped to ``(N, H, W, 1)``, because the network starts
    with a ``Conv2D`` layer, which requires 4-D input; the previous
    ``(N, H * W, 1)`` layout would be rejected by Keras.

    :return: ``(X, Y), (Xs, Ys)`` training and test arrays.
    :raises NotImplementedError: for the CNN when the image size of
        `dataset` is unknown.
    """
    (X, Y), (Xs, Ys) = get_dataset(dataset)
    if model_type == "cnn":
        # The small test subset is ordinary MNIST, so it shares the 28x28 size.
        if dataset in ("mnist", "mnist-small-subset-for-test"):
            H, W = 28, 28
        elif dataset == "cifar":
            # NOTE(review): get_dataset currently rejects "cifar", so this
            # branch is not reachable through it; the single-channel reshape
            # below is unverified for that dataset.
            H, W = 32, 32
        else:
            raise NotImplementedError
        X = X.reshape(-1, H, W, 1)
        Xs = Xs.reshape(-1, H, W, 1)
    return (X, Y), (Xs, Ys)
@ex.capture
def experiment_name(model_type, lr_cfg, num_inducing_points, batch_size, dataset,
                    base_kern, init_patches, patch_shape, date):
    """Build the run identifier from the sacred configuration.

    The identifier starts with ``<model_type>_<date>`` followed by
    underscore-separated ``key-value`` fragments: learning rate and batch
    size for the CNN, and additionally patch initialisation, kernel name,
    decay mode, number of inducing points and patch size otherwise.
    `dataset` is accepted but does not appear in the name.
    """
    name = f"{model_type}_{date}"
    if model_type == "cnn":
        fragments = [
            f"lr-{lr_cfg['lr']}",
            f"batchsize-{batch_size}",
        ]
    else:
        fragments = [
            f"initpatches-{init_patches}",
            f"kern-{base_kern}",
            f"lr-{lr_cfg['lr']}",
            f"lrdecay-{lr_cfg['decay']}",
            f"nip-{num_inducing_points}",
            f"batchsize-{batch_size}",
            f"patch-{patch_shape[0]}",
        ]
    return "_".join([name] + fragments)
@ex.capture
def experiment_path(basepath, dataset):
    """Return ``<basepath>/<dataset>/<experiment name>`` as a string.

    The directory, including missing parents, is created on every call if
    it does not exist yet.
    """
    target = Path(basepath) / dataset / experiment_name()
    target.mkdir(parents=True, exist_ok=True)
    return str(target)
#########
## ConvGP
#########
# Called from convgp_fit; only restores when the `restore` config flag is set (off by default).
@ex.capture
def restore_session(session, restore):
    """Restore `session` from the experiment directory when `restore` is set.

    The experiment path is always resolved first, which also creates the
    directory. Restoration through ``mon.restore_session`` only happens when
    `restore` is truthy and that path is a directory.
    """
    checkpoint_dir = experiment_path()
    if not restore or not os.path.isdir(checkpoint_dir):
        return
    mon.restore_session(session, checkpoint_dir)
    print("Model restored")
@ex.capture
def get_likelihood(dataset):
    """Pick the GPflow likelihood for the configured `dataset`.

    ``"mnist01"`` gets a Bernoulli likelihood; every other dataset gets a
    ``SoftMax`` constructed with 10.
    """
    is_binary = dataset == "mnist01"
    if is_binary:
        return gpflow.likelihoods.Bernoulli()
    return gpflow.likelihoods.SoftMax(10)
@ex.capture
def patch_initializer(X, M, patch_shape, init_patches):
    """Return `M` initial inducing patches, each flattened to one row.

    `patch_shape` and `init_patches` are injected from the sacred
    configuration.

    * ``"random"``: standard-normal samples of shape ``[M, prod(patch_shape)]``.
    * ``"patches-unique"``: every row of `X` is viewed as a square image, all
      of its patches are extracted, duplicates are removed and `M` of the
      remaining patches are chosen at random.

    :param X: array with one flattened square image per row.
    :param M: number of patches to return.
    :return: array of shape ``[M, prod(patch_shape)]``.
    :raises ValueError: if fewer than `M` unique patches are available.
    :raises NotImplementedError: for any other `init_patches` value.
    """
    patch_len = np.prod(patch_shape)
    if init_patches == "random":
        return np.random.randn(M, patch_len)
    if init_patches != "patches-unique":
        raise NotImplementedError

    imh = int(X.shape[1] ** 0.5)  # side length of the square image
    patches = np.concatenate(
        [extract_patches_2d(im.reshape(imh, imh), patch_shape) for im in X],
        axis=0)
    patches = np.unique(patches.reshape(-1, patch_len), axis=0)
    # Fail with a clear message instead of an opaque reshape error further on.
    if len(patches) < M:
        raise ValueError(
            f"Requested {M} inducing patches but only {len(patches)} unique "
            f"patches could be extracted from the {len(X)} images provided.")
    idx = np.random.permutation(len(patches))[:M]
    return patches[idx]  # [M, w x h]
@gpflow.defer_build()
@ex.capture
def convgp_setup_model(train_data, batch_size,
                       patch_shape, num_inducing_points):
    """Assemble the (not yet compiled) SVGP model with a convolutional kernel.

    `batch_size`, `patch_shape` and `num_inducing_points` come from the
    sacred configuration. The inducing patches are initialised from the
    first 400 training images, the base kernel is a squared exponential
    (the `base_kern` config entry is not consulted here), and ``q_mu`` is
    overwritten with standard-normal noise in GPflow's float type.
    """
    X, Y = train_data
    side = int(X.shape[1] ** .5)  # images are treated as side x side squares

    likelihood = get_likelihood()
    num_latent = getattr(likelihood, 'num_classes', 1)

    patches = patch_initializer(X[:400], num_inducing_points, patch_shape)

    patch_kern = gpflow.kernels.SquaredExponential(np.prod(patch_shape))
    kern = gpflow.kernels.WeightedConvolutional(
        patch_kern, [side, side], patch_size=patch_shape)
    feat = gpflow.features.InducingPatch(patches)

    # Initial hyperparameters of the base kernel.
    kern.basekern.variance = 25.0
    kern.basekern.lengthscales = 1.2

    model = gpflow.models.SVGP(
        X, Y, kern, likelihood,
        num_latent=num_latent,
        feat=feat,
        minibatch_size=batch_size,
        name="gp_model")

    q_mu_shape = model.q_mu.read_value().shape
    model.q_mu = np.random.randn(*q_mu_shape).astype(gpflow.settings.float_type)
    return model
@ex.capture
def convgp_monitor_tasks(train_data, model, optimizer, hz, dataset):
    """Create the list of monitor tasks used while training the conv GP.

    `hz` and `dataset` come from the sacred configuration. Six tasks are
    built, in this order: learning rate, checkpoint, model parameters,
    timing print-out, error on the first 500 examples and error on all
    examples of `train_data`. All but the last run at the ``hz['short']``
    period; the last runs at ``hz['slow']``.

    :return: list of ``gpflow.training.monitor`` tasks.
    """
    Xs, Ys = train_data
    path = experiment_path()
    fw = mon.LogdirWriter(path)

    def lr(*args, **kwargs):
        # Evaluate the optimizer's current learning rate in the model session.
        return model.enquire_session().run(optimizer._optimizer._lr)

    def every(kind):
        return mon.PeriodicIterationCondition(hz[kind])

    def scalar_task(func, label, kind):
        # TensorBoard scalar task that is also run on exit and flushed at once.
        return (mon.ScalarFuncToTensorBoardTask(fw, func, label)
                .with_name(label)
                .with_condition(every(kind))
                .with_exit_condition(True)
                .with_flush_immediately(True))

    if dataset == "mnist01":
        error_func = calc_binary_error
    else:
        error_func = calc_multiclass_error

    tasks = [
        scalar_task(lr, "lr", 'short'),
        mon.CheckpointTask(path)
        .with_name('saver')
        .with_condition(every('short')),
        mon.ModelToTensorBoardTask(fw, model)
        .with_name('model_tboard')
        .with_condition(every('short'))
        .with_exit_condition(True)
        .with_flush_immediately(True),
        mon.PrintTimingsTask()
        .with_name('print')
        .with_condition(every('short'))
        .with_exit_condition(True),
        scalar_task(get_error_cb(model, Xs, Ys, error_func), "error", 'short'),
        scalar_task(get_error_cb(model, Xs, Ys, error_func, full=True),
                    "error_full", 'slow'),
    ]

    print("# tasks:", len(tasks))
    return tasks
@ex.capture
def convgp_setup_optimizer(model, global_step, lr_cfg):
    """Create the Adam optimizer, optionally with a step-wise decaying rate.

    With ``lr_cfg['decay'] == "custom"`` the base rate is divided by
    ``1 + (global_step // decay_steps) / 3``; otherwise the constant base
    rate is used. `model` is accepted but not used.
    """
    base_lr = lr_cfg['lr']
    if lr_cfg['decay'] != "custom":
        return gpflow.train.AdamOptimizer(base_lr)

    print("Custom decaying lr")
    completed_stages = global_step // lr_cfg['decay_steps']
    shrink = 1 + completed_stages / 3
    return gpflow.train.AdamOptimizer(base_lr * 1.0 / shrink)
@ex.capture
def convgp_fit(train_data, test_data, iterations, float_type, jitter_level):
    """Build, train and evaluate the convolutional GP.

    `iterations`, `float_type` and `jitter_level` are injected from the
    sacred configuration. `float_type` is looked up as an attribute of
    numpy (e.g. "float64").

    The statements below are order-sensitive: settings are pushed before the
    session and model are created, and the optimizer is touched once before
    the monitor and the optional restore.
    """
    # Override GPflow's float type and jitter level for everything built below.
    custom_settings = gpflow.settings.get_settings()
    custom_settings.dtypes.float_type = getattr(np, float_type)
    custom_settings.numerics.jitter_level = jitter_level
    gpflow.settings.push(custom_settings)
    session = gpflow.get_default_session()
    # Global step shared by the learning-rate schedule and the monitor.
    step = mon.create_global_step(session)
    model = convgp_setup_model(train_data)
    model.compile()
    optimizer = convgp_setup_optimizer(model, step)
    # Zero-iteration call; presumably this creates the optimizer's variables
    # so they exist for the monitor tasks, restore and print below - TODO confirm.
    optimizer.minimize(model, maxiter=0)
    # Note: the monitor's error tasks are evaluated on the training data.
    monitor_tasks = convgp_monitor_tasks(train_data, model, optimizer)
    monitor = mon.Monitor(monitor_tasks, session, step, print_summary=True)
    # Restores only when the `restore` config flag is set.
    restore_session(session)
    print(session.run(optimizer.optimizer.variables()[:3]))
    with monitor:
        optimizer.minimize(model,
                           step_callback=monitor,
                           maxiter=iterations,
                           global_step=step)
    # Report final errors and save the trained model.
    convgp_finish(train_data, test_data, model)
def convgp_save(model):
    """Save `model` as ``convgp.gpflow`` inside the experiment directory."""
    filename = experiment_path() + '/convgp.gpflow'
    gpflow.Saver().save(filename, model)
    print(f"Model saved at {filename}")
@ex.capture
def convgp_finish(train_data, test_data, model, dataset):
    """Print the final test and train errors, then save the model.

    `dataset` is injected from the sacred configuration and selects the
    binary error for ``"mnist01"`` and the multiclass error otherwise. Both
    errors are computed over the complete respective data set.

    Previously the "train" figure was produced by the same callback as the
    "test" figure, so the test error was printed twice.
    """
    X, Y = train_data
    Xs, Ys = test_data
    error_func = calc_binary_error if dataset == "mnist01" else calc_multiclass_error
    test_error = get_error_cb(model, Xs, Ys, error_func, full=True)
    train_error = get_error_cb(model, X, Y, error_func, full=True)
    print(f"Error test: {test_error()}")
    print(f"Error train: {train_error()}")
    convgp_save(model)
######
## CNN
######
def cnn_monitor_callbacks():
    """Return the Keras callbacks used while training the CNN.

    The list holds a ``ModelCheckpoint`` writing into the experiment
    directory (``period=10``, verbose) followed by a ``TensorBoard``
    callback logging to the same directory.
    """
    path = experiment_path()
    checkpoint_pattern = path + '/cnn.{epoch:02d}-{val_acc:.2f}.h5'
    return [
        ModelCheckpoint(checkpoint_pattern, verbose=1, period=10),
        TensorBoard(path),
    ]
def cnn_setup_model():
    """Build the uncompiled Keras CNN.

    Layer stack: conv(32) - max-pool - conv(64) - max-pool - flatten -
    dense(1024, relu) - dropout(0.5) - dense(10, softmax). Convolutions use
    5x5 kernels, stride 1, 'same' padding and relu; pooling is 2x2 with
    stride 2. All weights start from a truncated normal with stddev 0.1 and
    dense biases from the constant 0.1.
    """
    def weight_init():
        return truncated_normal(stddev=0.1)

    def conv(num_kernels):
        return Conv2D(num_kernels, (5, 5), (1, 1),
                      activation='relu', padding='same',
                      kernel_initializer=weight_init())

    def pool():
        return MaxPool2D(pool_size=(2, 2), strides=(2, 2))

    def dense(num_outputs, activation=None):
        return Dense(num_outputs,
                     activation=activation,
                     bias_initializer=constant(0.1),
                     kernel_initializer=weight_init())

    return Sequential([
        conv(32),
        pool(),
        conv(64),
        pool(),
        Flatten(),
        dense(1024, activation='relu'),
        Dropout(0.5),
        dense(10, activation='softmax'),
    ])
@ex.capture
def cnn_fit(train_data, test_data, batch_size, iterations):
    """Train the CNN baseline and print its metrics on the test data.

    `batch_size` and `iterations` are injected from the sacred
    configuration. The iteration budget is converted to whole epochs using
    the number of full batches per pass over the training data. The test
    data doubles as validation data during fitting.
    """
    x, y = train_data
    model = cnn_setup_model()

    epochs = iterations // (x.shape[0] // batch_size)
    callbacks = cnn_monitor_callbacks()

    compile_options = dict(optimizer='adam',
                           loss='sparse_categorical_crossentropy',
                           metrics=['mse', 'accuracy'])
    model.compile(**compile_options)

    model.fit(x, y,
              batch_size=batch_size,
              epochs=epochs,
              callbacks=callbacks,
              validation_data=test_data)

    xt, yt = test_data
    test_metrics = model.evaluate(xt, yt, batch_size=batch_size)
    print(f"Test metrics: {test_metrics}")
@ex.capture
@ex.automain
def main(model_type):
    """Entry point: load the data and train the configured model type.

    ``model_type == "cnn"`` trains the Keras CNN; any other value trains
    the convolutional GP. Always returns 0.
    """
    train_data, test_data = get_data()
    fit = cnn_fit if model_type == "cnn" else convgp_fit
    fit(train_data, test_data)
    return 0