https://github.com/stwisdom/urnn
Tip revision: 9f8b679c03683d0edd3f9a38d5a7cd0eef25a1c5 authored by Scott T Wisdom on 28 April 2017, 20:34:56 UTC
Update README.md
Update README.md
Tip revision: 9f8b679
timit_prediction.py
import sys
sys.setrecursionlimit(10000)
import cPickle
import gzip
import theano
import pdb
from fftconv import cufft, cuifft
import numpy as np
import theano.tensor as T
from theano.ifelse import ifelse
from models import *
from optimizations import *
import argparse, timeit, time
import os
import scipy
import scipy.io.wavfile
import scipy.fftpack as fft
import scipy.signal
import scipy.linalg
import librosa
from util import (stft_mc,iAugSTFT,wavwrite)
def wavread(wavfile):
    """Load a WAV file and return its samples as scaled float32.

    The array returned by ``scipy.io.wavfile.read`` is transposed (so a
    ``nsampl x nch`` array becomes ``nch x nsampl``), cast to float32 and
    divided by 32768.0.  The sample rate reported by the reader is discarded.
    """
    _rate, samples = scipy.io.wavfile.read(wavfile)
    # transpose first, then cast, then scale -- same order as before
    return np.transpose(samples).astype(np.float32) / 32768.0
def iAugFFT(Xaug,axis=0):
    """Inverse FFT of a half spectrum stored in real-composite ("augmented") form.

    Along `axis`, the first half of `Xaug` is taken as real parts and the
    second half as imaginary parts of F complex bins.  A length 2*F-2 spectrum
    is assembled from the conjugate of those F bins followed by bins
    F-2, ..., 1 (unconjugated), and the real part of its inverse FFT along
    `axis` is returned.

    Parameters
    ----------
    Xaug : ndarray whose size along `axis` is 2*F.
    axis : axis holding the stacked real/imaginary parts (default 0).

    Returns
    -------
    ndarray of the same shape as `Xaug` except for size 2*F-2 along `axis`.
    """
    # floor division: F is used to build index arrays, so it must stay an
    # integer even when "/" means true division
    F=Xaug.shape[axis]//2
    real_part=np.take(Xaug,np.arange(0,F),axis=axis)
    imag_part=np.take(Xaug,np.arange(F,2*F),axis=axis)
    X=real_part+np.complex64(1j)*imag_part
    # interior bins in reverse order (excludes the first and last bin)
    mirrored=np.take(X,np.arange(F-2,0,-1),axis=axis)
    X=np.concatenate((X.conj(),mirrored),axis=axis)
    xr=fft.ifft(X,axis=axis).real
    return xr
def load_wavfiles_names(path):
    """Recursively collect the paths of all files under `path` ending in '.wav'.

    Paths are built by joining each directory visited by ``os.walk`` with the
    matching file names, and are returned in traversal order.
    """
    return [
        os.path.join(dirpath, fname)
        for dirpath, _subdirs, fnames in os.walk(path)
        for fname in fnames
        if fname.endswith('.wav')
    ]
def normalize_data(x,data_normalization,data_type,mask=None,n=None):
    """Optionally mean- and/or variance-normalize a batch of padded sequences.

    `x` is indexed as (frame, utterance, feature).  Statistics are summed over
    frames (axis 0) using the mask, then averaged over utterances (axis 1)
    unless 'perutt' appears in `data_normalization`, in which case they are
    kept per utterance.

    Parameters
    ----------
    x : 3-D ndarray of data.
    data_normalization : str, matched case-insensitively by substring:
        'mean' subtracts the mean, 'var' divides by the standard deviation,
        'perutt' keeps per-utterance statistics.  Any combination may appear;
        a string containing none of them leaves `x` untouched.
    data_type : 'real' or 'complex'.  Only consulted when 'var' is requested.
        For 'complex', the last axis of `x` holds `n` real parts followed by
        the imaginary parts, and the variance is that of the complex value.
    mask : optional array indexed like `x`; only its first feature column
        (``mask[:,:,0:1]``) is used to mark valid frames.  Defaults to ones.
    n : number of complex features when `data_type` is 'complex'.  Defaults to
        ``x.shape[2]``.

    Returns
    -------
    (x, stats) : the normalized data and a dict holding 'mean' and/or 'std'
        for whichever normalizations were applied.

    Raises
    ------
    ValueError
        If 'var' is requested and `data_type` is neither 'real' nor 'complex'.
    """
    data_normalization=data_normalization.lower()
    if mask is None:
        mask=np.ones_like(x)
    if n is None:
        n=x.shape[2]
    stats={}
    want_mean=('mean' in data_normalization)
    want_var=('var' in data_normalization)
    if not (want_mean or want_var):
        # nothing requested: leave the data and the mask untouched
        return x, stats
    if ('perutt' in data_normalization):
        # empty axis tuple: np.mean reduces over nothing, so per-utterance
        # statistics are kept as they are
        axes_mean=()
    else:
        axes_mean=1
    # valid-frame indicator and number of valid frames per utterance
    frame_mask=mask[:,:,0:1]
    frame_count=np.float32(np.sum(frame_mask,axis=0,keepdims=True))
    if want_mean:
        stats['mean']=np.mean(np.sum(x*frame_mask,axis=0,keepdims=True)/frame_count, axis=axes_mean)
        # subtract only on valid frames
        x=x-(stats['mean']*frame_mask)
    if want_var:
        if (data_type=='real'):
            power=x**2
        elif (data_type=='complex'):
            power=x[:,:,:n]**2+x[:,:,n:]**2
        else:
            raise ValueError("data_type must be 'real' or 'complex' for variance normalization, got %r" % (data_type,))
        x_var=np.sum(power*frame_mask,axis=0,keepdims=True)/frame_count
        stats['std']=np.sqrt(np.mean(x_var,axis=axes_mean))
        if (data_type=='real'):
            scale=np.float32(1e-7)+stats['std']
        else:
            # std is per complex feature: repeat it for the real and imaginary
            # halves and scale by sqrt(2)
            scale=np.float32(1e-7)+np.sqrt(2).astype(np.float32)*np.tile(stats['std'],(1,1,2))
        x=x*frame_mask/scale
    return x, stats
def generate_data(wavfiles,params_stft,prng,flag_unwrap_phase=True):
    """Build concatenated STFT targets and matching random inputs from WAV files.

    Each file is read with `wavread`, transformed with `stft_mc`, restricted
    to the first `nch` channels, and reshaped so channels are stacked along
    the first (frequency) dimension.  The per-file STFTs are concatenated
    along the frame axis.

    Parameters
    ----------
    wavfiles : sequence of WAV file paths.
    params_stft : dict with keys 'N', 'hop', 'nch' and 'window'; 'N', 'hop'
        and 'window' are passed to `stft_mc`.
    prng : object providing ``randn`` (e.g. ``np.random.RandomState``), used
        to draw the random inputs.
    flag_unwrap_phase : if true, each STFT phase is unwrapped along the frame
        axis and the term ``2*pi*(f/N)*(t*hop)`` is subtracted.

    Returns
    -------
    Xaug : (2*F) x n_frames_total array of Gaussian samples divided by
        sqrt(2), with F = N//2+1.
    Yaug : (2*nch*F) x n_frames_total array, real parts stacked over
        imaginary parts of the concatenated STFTs.
    fidx : len(wavfiles) x 2 int32 array; row i holds the start and end
        (exclusive) frame index of file i in the concatenation.
    """
    N=params_stft['N']
    hop=params_stft['hop']
    nch=params_stft['nch']
    window=params_stft['window']
    # floor division keeps F an integer (it is used in array shapes)
    F=N//2+1
    n_files=len(wavfiles)
    # frame index ranges for individual files
    fidx=np.zeros((n_files,2)).astype(np.int32)
    # per-file STFTs are collected and concatenated once after the loop;
    # the leading empty complex64 block keeps the result well-defined (and of
    # the same dtype as before) when wavfiles is empty
    Y_blocks=[np.zeros((nch*F,0)).astype(np.complex64)]
    ifidx=0
    for ifile,wavfile in enumerate(wavfiles):
        print("Read file %d of %d total: %s" % (ifile+1,n_files,wavfile))
        # read in reference output audio
        y=wavread(wavfile)
        Ycur=stft_mc(y,N,hop,window)
        Ycur=Ycur[:,:,:nch] #restrict to desired number of channels
        Ycur=np.transpose(Ycur,(0,2,1)) #is now F x nch x nfram
        Ycur=np.reshape(Ycur,(nch*F,Ycur.shape[2]),order='F') #stack multiple channels in first dimension
        # record this file's frame range
        nfram=Ycur.shape[1]
        fidx[ifile,0]=ifidx
        ifidx+=nfram
        fidx[ifile,1]=ifidx
        if flag_unwrap_phase:
            # remove window hop phases
            # NOTE(review): the outer product below has F rows while Ycur has
            # nch*F rows, so this step looks like it assumes nch==1 -- confirm
            Yphase=np.float32(np.unwrap(np.angle(Ycur),axis=1))
            frange=np.arange(0,F,dtype=np.float32)/N
            trange=np.arange(0,nfram,dtype=np.float32)*hop
            Yphase=Yphase-2*np.pi*np.outer(frange,trange)
            Ycur=np.abs(Ycur)*np.exp(1j*Yphase)
        Y_blocks.append(Ycur)
    Y=np.concatenate(Y_blocks,axis=1)
    # NOTE(review): Xaug has 2*F rows whereas Yaug has 2*nch*F rows -- confirm
    # whether inputs are meant to be single-channel when nch>1
    Xaug=prng.randn(2*F,Y.shape[1])/np.sqrt(2) #unit variance circular complex Gaussians
    Yaug=np.concatenate((np.real(Y),np.imag(Y)),axis=0)
    return Xaug,Yaug,fidx
def generate_synth_data(n_seq,time_steps,sizes,prng,Winit='svd'):
    """Draw random inputs and a recurrence matrix for the synthetic task.

    Parameters
    ----------
    n_seq : number of sequences.
    time_steps : number of frames per sequence.
    sizes : dict with keys 'n_input' and 'n_hidden'.
    prng : object providing ``randn`` (e.g. ``np.random.RandomState``).
    Winit : how the recurrence matrix is built:
        'svd'     -- random complex matrix W replaced by U.dot(Vh) from its
                     SVD, then written in real-composite form.
        'adhoc'   -- value of the first parameter returned by
                     ``initialize_unitary(n_hidden,'full',prng)``.
        'adhoc2x' -- product of two such matrices, rebuilt from its first
                     block row.

    Returns
    -------
    Xaug : (2*n_input) x (n_seq*time_steps) Gaussian samples divided by sqrt(2).
    Waug : (2*n_hidden) x (2*n_hidden) real matrix.
    fidx : n_seq x 2 array; row i holds the start and end (exclusive) frame
        index of sequence i within Xaug's columns.

    Raises
    ------
    ValueError
        If `Winit` is not one of the supported options.
    """
    if Winit not in ('svd','adhoc','adhoc2x'):
        # reject unknown modes up front, before any random numbers are drawn,
        # instead of failing later on an unbound Waug
        raise ValueError("Winit must be 'svd', 'adhoc', or 'adhoc2x', got %r" % (Winit,))
    n_input=sizes['n_input']
    n_hidden=sizes['n_hidden']
    Xaug=prng.randn(2*n_input,n_seq*time_steps).astype(np.float32)/np.sqrt(2) #unit variance circular complex Gaussians in real-composite form
    if (Winit=='svd'):
        W=prng.randn(n_hidden,n_hidden).astype(np.complex64)+1j*prng.randn(n_hidden,n_hidden).astype(np.complex64)
        # np.linalg.svd returns the third factor already conjugate-transposed
        U, _, Vh = np.linalg.svd(W)
        W = np.dot(U,Vh)
        # convert W to real-composite form for right multiplication
        # real-composite for right multiplication, g=h^T W, with h=x+jy and W=A+jB,
        # is grc=hrc^T Wrc, with Wrc=[A^T, B^T; -B^T, A^T] and hrc=[x; y]
        #
        # real-composite for left multiplication, g=Wh, with h=x+jy and W=A+jB,
        # is grc=Wrc hrc, with Wrc=[A, -B; B, A] and hrc=[x; y]
        A=np.transpose(np.real(W))
        B=np.transpose(np.imag(W))
        top_row   =np.concatenate([A,B],axis=1)      # [ A, B]
        bottom_row=np.concatenate([(-1)*B,A],axis=1) # [-B, A]
        Waug=np.concatenate([top_row,bottom_row],axis=0)
    elif (Winit=='adhoc'):
        Wparams=initialize_unitary(n_hidden,'full',prng)
        Waug=Wparams[0].get_value()
    else:
        # 'adhoc2x'
        Wparams1=initialize_unitary(n_hidden,'full',prng)
        Waug1np=Wparams1[0].get_value()
        Waug1np=Waug1np[:n_hidden,:] # only take first row of blocks to get correct augmented form after multiplication within numerical precision
        Wparams2=initialize_unitary(n_hidden,'full',prng)
        Waug_row1=np.dot(Waug1np,Wparams2[0].get_value())
        # second block row is rebuilt from the first as [-right block, left block]
        Waug_row2=np.concatenate([-Waug_row1[:,n_hidden:],Waug_row1[:,:n_hidden]],axis=1)
        Waug=np.concatenate([Waug_row1,Waug_row2],axis=0)
    fidx0=np.arange(0,n_seq*time_steps,time_steps)
    fidx1=np.arange(time_steps,n_seq*time_steps+time_steps,time_steps)
    fidx=np.concatenate([np.reshape(fidx0,(n_seq,1)), np.reshape(fidx1,(n_seq,1))], axis=1)
    return Xaug, Waug, fidx
def main(n_iter, n_batch, n_hidden, learning_rate, savefile, model, input_type, out_every_t, loss_function, fold, scene, n_reflections=None,flag_telescope=True,nch=1,flag_unwrap_phase=True,indir="audio_8khz",outdir=None,dataset="timit",initfile=None,flag_feed_forward=True,flag_generator=False,downsample_train=1,downsample_test=1,time_steps=None,n_Givens=None,prng_seed_Givens=52016,num_allowed_test_inc=10,iters_per_validCheck=20,flag_useFullW=False,flag_onlyOptimW=True,lam=np.float32(0.0),Vnorm=np.float32(0.0),Unorm=np.float32(0.0),n_layers=1,num_pred_steps=0,hidden_bias_mean=0.1,data_transform='',bwe_frac=np.float32(1.0),data_normalization='none',offset_eval=None,olap=50,window=None,flag_noDiv=0,flag_noComplexConstraint=0,Winit='svd',seed=1234,optim_alg="rmsprop",n_utt_eval_spec=-1):
if offset_eval<0:
offset_eval=None
cost_weight=None
cost_transform=None
# --- Set data params ----------------
if (dataset=='timit16'):
N=512 #32 ms at fs=16kHz
else:
N=256 #32 ms at fs=8kHz
hop=np.round(np.float32(N)*np.float32(100.0-olap)/100.0).astype(np.int)
if (window=='hann'):
window=scipy.signal.hann(N,sym=False)
elif (window=='sqrt_hann'):
window=np.sqrt(scipy.signal.hann(N,sym=False))
else:
window=None
params_stft={'N': N, 'hop': hop, 'nch': nch, 'window': window} #STFT parameters
F=N/2+1
n_input =F #we're stacking multiple channels on top of each other
n_output=n_input #because we are building an autoencoder
ds_train = 1 #downsampling factor for training
ds_test = 1 #downsampling factor for test
#set paths:
if (dataset=='timit') or (dataset=='timit_trainNoSA_dev_coreTest'):
path_train=''.join(["/data1/timit/TIMIT_8khz/TRAIN"])
path_test =''.join(["/data1/timit/TIMIT_8khz/TEST"])
elif (dataset=='timit16'):
path_train=''.join(["/data1/timit/TIMIT_16khz/TRAIN"])
path_test =''.join(["/data1/timit/TIMIT_16khz/TEST"])
elif (dataset=='synthgen'):
path_train=None
path_test=None
else:
raise ValueError("dataset must be synthgen or timit")
if not (dataset == 'synthgen'):
#load and downsample wavfiles list for training
wavfiles_train_all=load_wavfiles_names(path_train)
wavfiles_train =wavfiles_train_all[::ds_train]
n_train=len(wavfiles_train)
#load and downsample wavfiles list for test
wavfiles_test_all =load_wavfiles_names(path_test)
wavfiles_test =wavfiles_test_all[::ds_test]
else:
n_train=int(2e4)
n_test =int(2e3)
n_input=n_hidden
n_output=n_hidden
sizes ={'n_input': n_input, 'n_hidden': n_hidden, 'n_output': n_output}
num_batches = int(n_train / n_batch)
# --- Create data --------------------
# set up random number generators for repeatable results
train_prng = np.random.RandomState(5678)
test_prng = np.random.RandomState(42)
# generate and/or load the data
savefile_timit_data=None
if (dataset=='timit'):
savefile_timit_data='timit_data'
elif (dataset=='timit16'):
savefile_timit_data='/data1/swisdom/timit16_data'
elif (dataset=='timit_trainNoSA_dev_coreTest'):
savefile_timit_data='timit_data_trainNoSA_dev_coreTest'
if ('timit' in dataset) and (os.path.isfile(savefile_timit_data) or os.path.isfile(savefile_timit_data+'_train_xdata_stack')):
# we're using TIMIT and a save file for TIMIT exists, so load up the data
print "Save file %s for TIMIT data exists, loading it from the hard drive..." % savefile_timit_data
if (dataset=='timit') or (dataset=='timit_trainNoSA_dev_coreTest'):
L=cPickle.load(file(savefile_timit_data,'r'))
print "Loaded TIMIT data"
train_z_stack=L['train_z_stack']
train_xdata_stack=L['train_xdata_stack']
fidx_train=L['fidx_train']
test_z_stack=L['test_z_stack']
test_xdata_stack=L['test_xdata_stack']
fidx_test=L['fidx_test']
elif (dataset=='timit16'):
for key in ['train_z_stack','train_xdata_stack','fidx_train','test_z_stack','test_xdata_stack','fidx_test']:
print "Broken exec statement"
#exec("%s=np.load(file(savefile_timit_data+'_'+key,'rb'))" % key)
n_train=fidx_train.shape[0]
num_batches = int(n_train / n_batch)
elif not (dataset == 'synthgen'):
# we aren't using the synthgen dataset, or we aren't using the timit dataset,
# or the savefile for timit data doesn't exist, so load data using the lists
# of wavfiles and generate associated random data
if (dataset=='timit_trainNoSA_dev_coreTest'):
# adjust wavfiles lists to exclude SA utterances from train
# and make the test set concatenated TIMIT dev set and
# core test set
wavfiles_train = [x for x in wavfiles_train if (not ('sa' in x.lower()))]
wavfiles_test = [x for x in wavfiles_test if (not ('sa' in x.lower()))]
speakers_dev = [line.rstrip('\n') for line in open('timit_dev_spk.list')]
wavfiles_dev = [x for x in wavfiles_test if any(speaker in x.lower() for speaker in speakers_dev)]
speakers_coreTest = [line.rstrip('\n') for line in open('timit_test_spk.list')]
wavfiles_coreTest = [x for x in wavfiles_test if any(speaker in x.lower() for speaker in speakers_coreTest)]
wavfiles_extraTest = [x for x in wavfiles_test if (not (x in wavfiles_dev+wavfiles_coreTest))]
wavfiles_test=wavfiles_dev+wavfiles_coreTest+wavfiles_extraTest
train_z_stack, train_xdata_stack, fidx_train = generate_data(wavfiles_train,
params_stft,
train_prng,
flag_unwrap_phase)
test_z_stack, test_xdata_stack, fidx_test = generate_data(wavfiles_test,
params_stft,
test_prng,
flag_unwrap_phase)
# z are 2*nch*F x \sum_utt n_fram(utt), xdata are 2*nsrc*nch*F x \sum_utt n_fram(utt)
# if we're doing TIMIT and the save file doesn't exist, write it out:
if ( 'timit' in dataset ) and not (os.path.isfile(savefile_timit_data) or os.path.isfile(savefile_timit_data+'_train_xdata_stack')):
print "Saving TIMIT data to file %s" % savefile_timit_data
# we have read in and generated z's for TIMIT data; save it off
save_vals_timit_data={'train_z_stack': train_z_stack,
'train_xdata_stack': train_xdata_stack,
'fidx_train': fidx_train,
'test_z_stack': test_z_stack,
'test_xdata_stack': test_xdata_stack,
'fidx_test': fidx_test}
#if (dataset=='timit'):
if not (dataset=='timit16'):
cPickle.dump(save_vals_timit_data, file(savefile_timit_data, 'wb'), cPickle.HIGHEST_PROTOCOL)
#elif (dataset=='timit16'):
else:
for key in save_vals_timit_data.keys():
np.save(file(savefile_timit_data+'_'+key,'wb'),save_vals_timit_data[key])
"""
else:
print "Unknown timit dataset name %s" % dataset
return
"""
else:
# we're using the synthgen dataset, so use a different function, generate_synth_data, to create data
train_z_stack, synth_Waug, fidx_train = generate_synth_data(n_train,time_steps,sizes,train_prng,Winit)
#train_xdata_stack = np.zeros_like(train_z_stack) #set xdata to 0, since we'll generate these later
train_xdata_stack = np.zeros((2*n_output,n_train*time_steps))
test_z_stack, extra_Waug, fidx_test = generate_synth_data(n_test, time_steps,sizes,test_prng,Winit)
#test_xdata_stack = np.zeros_like(test_z_stack)
test_xdata_stack = np.zeros((2*n_output,n_test*time_steps))
# if input z is real-valued,
# train_z_stack should be of dimension n_framMax x n_input x n_utt
# if input z is complex-valued,
# train_z_stack should be of dimension n_framMax x 2*n_input x n_utt
# if xdata is real-valued,
# train_xdata_stack should be of dimension n_framMax x n_output x n_utt
# if xdata is complex-valued,
# train_xdata_stack should be of dimension n_framMax x 2*n_output x n_utt
# check if we're doing an autoencoder (flag_generator=False) or a generator (flag_generator=True)
if not flag_generator:
# since we're doing an autoencoder, set train_z equal to train_xdata
train_z_stack=np.copy(train_xdata_stack)
test_z_stack=np.copy(test_xdata_stack)
#tweaks to ensure dynamical system output is relatively stable (determined by playing with parameters in Matlab):
## scale random data down for timit
#if (dataset == 'timit'):
# train_z_stack = train_z_stack*np.sqrt(1.0)
# test_z_stack = test_z_stack*np.sqrt(1.0)
# use a negative hidden bias mean
if (dataset == 'synthgen'):
hidden_bias_mean=-0.1
#else:
# hidden_bias_mean=0.0
# create padded inputs and outputs for train:
lens_train=fidx_train[:,1]-fidx_train[:,0]
n_framMax_train=np.max(lens_train)
n_utt_train=len(lens_train)
train_z = np.zeros((n_framMax_train,2*n_input, n_utt_train)).astype(np.float32)
train_xdata = np.zeros((n_framMax_train,2*n_output,n_utt_train)).astype(np.float32)
for iutt in range(n_utt_train):
train_z[:lens_train[iutt],:,iutt]=np.transpose(train_z_stack[:,fidx_train[iutt,0]:fidx_train[iutt,1]])
train_xdata[:lens_train[iutt],:,iutt]=np.transpose(train_xdata_stack[:,fidx_train[iutt,0]:fidx_train[iutt,1]])
# train_z is in augmented form and is now of dimensions n_framMax_train x 2*n_input x n_utt_train
# train_xdata is in augmented form and is now of dimensions n_framMax_train x 2*n_output x n_utt_train
# create padded inputs and outputs for test:
lens_test=fidx_test[:,1]-fidx_test[:,0]
n_framMax_test=np.max(lens_test)
n_utt_test=len(lens_test)
test_z = np.zeros((n_framMax_test,2*n_input, n_utt_test)).astype(np.float32)
test_xdata = np.zeros((n_framMax_test,2*n_output,n_utt_test)).astype(np.float32)
for iutt in range(n_utt_test):
test_z[:lens_test[iutt],:,iutt]=np.transpose(test_z_stack[:,fidx_test[iutt,0]:fidx_test[iutt,1]])
test_xdata[:lens_test[iutt],:,iutt]=np.transpose(test_xdata_stack[:,fidx_test[iutt,0]:fidx_test[iutt,1]])
# test_z is in augmented form and is now of dimensions n_framMax_test x 2*n_input x n_utt_test
# test_xdata is in augmented form and is now of dimensions n_framMax_test x 2*n_output x n_utt_test
# to get scan to work properly, transpose x and y to be of size n_framMax x n_utt x n_<input,output>
train_z=np.transpose(train_z,[0,2,1])
train_xdata=np.transpose(train_xdata,[0,2,1])
test_z =np.transpose(test_z,[0,2,1])
test_xdata =np.transpose(test_xdata,[0,2,1])
output_type='complex' #assume complex-valued data
if (data_transform=='logmag'):
print "Using log-magnitude transform on input and output data"
print ""
input_type='real'
output_type='real'
train_z=10.0*np.log10(1e-5 + train_z[:,:,:n_input]**2 + train_z[:,:,n_input:]**2)
train_xdata=10.0*np.log10(1e-5 + train_xdata[:,:,:n_output]**2 + train_xdata[:,:,n_output:]**2)
test_z=10.0*np.log10(1e-5 + test_z[:,:,:n_input]**2 + test_z[:,:,n_input:]**2)
test_xdata=10.0*np.log10(1e-5 + test_xdata[:,:,:n_output]**2 + test_xdata[:,:,n_output:]**2)
elif (data_transform=='logmag_phasePrediction'):
print "Using log-magnitude transform on input data, using linear complex for output, modifying cost function for phase prediction"
print ""
input_type='real'
output_type='real'
train_z=10.0*np.log10(1e-5 + train_z[:,:,:n_input]**2 + train_z[:,:,n_input:]**2)
test_z=10.0*np.log10(1e-5 + test_z[:,:,:n_input]**2 + test_z[:,:,n_input:]**2)
cost_transform='magTimesPhase'
loss_function="none_in_scan"
elif (data_transform=='time_domain_windowed'):
print "Using windowed time-domain frames for input and output data"
print ""
input_type='real'
output_type='real'
n_input=N
n_output=N
start_time=time.time()
train_z=iAugFFT(train_z,axis=2)
train_xdata=np.copy(train_z)
#train_xdata=iAugFFT(train_xdata,axis=2)
test_z=iAugFFT(test_z,axis=2)
test_xdata=np.copy(test_z)
#test_xdata=iAugFFT(test_xdata,axis=2)
elapsed_time = time.time() - start_time
print "Elapsed time to compute IFFTs: %f" % elapsed_time
print ""
if (bwe_frac<1.0):
if not (n_input==n_output):
print "Error: bwe_frac is less than 1.0, but n_input and n_output are not equal! Exiting..."
return
bwe_n_output = np.round(bwe_frac*n_output)
bwe_n_input = n_output - bwe_n_output
# grab lower indices of z as input
train_z=np.concatenate( [train_z[:,:,:bwe_n_input],train_z[:,:,n_input:n_input+bwe_n_input]],axis=2)
test_z=np.concatenate( [test_z[:,:,:bwe_n_input],test_z[:,:,n_input:n_input+bwe_n_input]],axis=2)
# grab upper indices of xdata as targets
train_xdata=np.concatenate( [train_xdata[:,:,bwe_n_input:n_output],train_xdata[:,:,n_output+bwe_n_input:]], axis=2)
test_xdata=np.concatenate( [test_xdata[:,:,bwe_n_input:n_output],test_xdata[:,:,n_output+bwe_n_input:]], axis=2)
n_input=bwe_n_input
n_output=bwe_n_output
if (num_pred_steps>0):
print "Predicting reference data %d steps ahead" % num_pred_steps
print ""
train_xdata[:n_framMax_train-num_pred_steps,:,:]=train_xdata[num_pred_steps:,:,:]
test_xdata[:n_framMax_test-num_pred_steps,:,:]=test_xdata[num_pred_steps:,:,:]
# apply downsampling factors to train and test data, if the factors are greater than 1
if (downsample_train>1):
train_z=train_z[:,0:n_utt_train:downsample_train,:]
train_xdata=train_xdata[:,0:n_utt_train:downsample_train,:]
lens_train=lens_train[0:n_utt_train:downsample_train]
num_batches=num_batches/downsample_train
n_train=n_train/downsample_train
n_utt_train=n_utt_train/downsample_train
if offset_eval is not None:
if (downsample_test==1):
# build eval data:
if (n_utt_eval_spec>0):
n_utt_eval=n_utt_eval_spec
eval_z=test_z[:,offset_eval:offset_eval+n_utt_eval,:]
eval_xdata=test_xdata[:,offset_eval:offset_eval+n_utt_eval,:]
lens_eval=lens_test[offset_eval:offset_eval+n_utt_eval]
else:
eval_z=test_z[:,offset_eval:n_utt_test,:]
eval_xdata=test_xdata[:,offset_eval:n_utt_test,:]
lens_eval=lens_test[offset_eval:n_utt_test]
n_utt_eval=n_utt_test-offset_eval
# clip test data:
test_z=test_z[:,:offset_eval,:]
test_xdata=test_xdata[:,:offset_eval,:]
lens_test=lens_test[:offset_eval]
n_utt_test=offset_eval
else:
eval_z=test_z[:,offset_eval:n_utt_test:downsample_test,:]
eval_xdata=test_xdata[:,offset_eval:n_utt_test:downsample_test,:]
lens_eval=lens_test[0:n_utt_test:downsample_test]
n_utt_eval=n_utt_test/downsample_test
if (downsample_test>1):
test_z=test_z[:,0:n_utt_test:downsample_test,:]
test_xdata=test_xdata[:,0:n_utt_test:downsample_test,:]
lens_test=lens_test[0:n_utt_test:downsample_test]
n_utt_test=n_utt_test/downsample_test
# set data masks, if the data sequences have unequal length
if ('timit' in dataset):
flag_use_mask=True
else:
flag_use_mask=False
if flag_use_mask:
train_xdata_mask=np.zeros((n_framMax_train,n_utt_train,1),dtype=np.int8)
for ii in xrange(lens_train.shape[0]):
train_xdata_mask[0:lens_train[ii],ii,:]=1
if offset_eval is not None:
eval_xdata_mask=np.zeros((n_framMax_test,n_utt_eval,1),dtype=np.int8)
for ii in xrange(lens_eval.shape[0]):
eval_xdata_mask[0:lens_eval[ii],ii,:]=1
test_xdata_mask=np.zeros((n_framMax_test,n_utt_test,1),dtype=np.int8)
for ii in xrange(lens_test.shape[0]):
test_xdata_mask[0:lens_test[ii],ii,:]=1
# apply normalization to data, if specified
print "Applying normalization of %s to data..." % data_normalization
print ""
stats={}
#normalize train data
if flag_use_mask:
train_mask=train_xdata_mask
else:
train_mask=None
train_z, stats['train_z_stats']=normalize_data(train_z,data_normalization,input_type,mask=train_mask,n=n_input)
train_xdata, stats['train_xdata_stats']=normalize_data(train_xdata,data_normalization,output_type,mask=train_mask,n=n_output)
#normalize test data
if flag_use_mask:
test_mask=test_xdata_mask
else:
test_mask=None
test_z, stats['test_z_stats']=normalize_data(test_z,data_normalization,input_type,mask=test_mask,n=n_input)
test_xdata, stats['test_xdata_stats']=normalize_data(test_xdata,data_normalization,output_type,mask=test_mask,n=n_output)
if offset_eval is not None:
#normalize eval data
if flag_use_mask:
eval_mask=eval_xdata_mask
else:
eval_mask=None
eval_z, stats['eval_z_stats']=normalize_data(eval_z,data_normalization,input_type,mask=eval_mask,n=n_input)
eval_xdata, stats['eval_xdata_stats']=normalize_data(eval_xdata,data_normalization,output_type,mask=eval_mask,n=n_output)
# --- Create theano graph and compute gradients ----------------------
gradient_clipping = np.float32(1)
if (model == 'LSTM'):
n_input_LSTM=n_input
n_output_LSTM=n_output
if (input_type=='complex'):
n_input_LSTM=2*n_input
if (output_type=='complex'):
n_output_LSTM=2*n_output
inputs, parameters, costs = LSTM(n_input_LSTM, n_hidden, n_output_LSTM, input_type=input_type,out_every_t=out_every_t, loss_function=loss_function,flag_use_mask=flag_use_mask,flag_return_lin_output=True,flag_return_hidden_states=True,cost_weight=cost_weight,cost_transform=cost_transform,seed=seed)
gradients = T.grad(costs[0], parameters)
gradients = [T.clip(g, -gradient_clipping, gradient_clipping) for g in gradients]
elif (model == 'complex_RNN'):
if flag_useFullW:
Wimpl='full'
else:
if (n_Givens is None) or (n_Givens < 1):
Wimpl='adhoc'
else:
Wimpl='givens'
# build computational graph for train and test
inputs, parameters, costs = complex_RNN(n_input, n_hidden, n_output, input_type=input_type,out_every_t=out_every_t, loss_function=loss_function,output_type=output_type,flag_feed_forward=flag_feed_forward,flag_return_lin_output=True,flag_use_mask=flag_use_mask,hidden_bias_mean=hidden_bias_mean,Wimpl=Wimpl,prng_Givens=np.random.RandomState(prng_seed_Givens),lam=lam,Vnorm=Vnorm,Unorm=Unorm,flag_return_hidden_states=True,n_layers=n_layers,cost_weight=cost_weight,cost_transform=cost_transform,flag_noComplexConstraint=flag_noComplexConstraint,seed=seed)
idx_project=None
if (dataset == 'synthgen'):
if flag_onlyOptimW:
# don't optimize V, U, or out_bias (elements 0, 1, and 4/3 of parameters)
if (Wimpl=='adhoc'):
# reflection and theta
parameters_optimize=[parameters[3],parameters[5]]
elif (Wimpl=='full'):
parameters_optimize=[parameters[5]]
idx_project=[0]
else:
parameters_optimize=parameters
if (Wimpl=='full'):
# since we're using a full W matrix, indicate its index in the
# parameters_optimize list to make sure we use Steifel manifold
# optimization on it:
idx_project=[5]
gradients = T.grad(costs[0], parameters_optimize)
# build computational graph for generating train and test data
inputs_synth, parameters_synth, costs_synth = complex_RNN(n_input, n_hidden, n_output, input_type=input_type,out_every_t=out_every_t, loss_function='none_in_scan',output_type='complex',flag_feed_forward=flag_feed_forward,flag_return_lin_output=True,flag_use_mask=flag_use_mask,hidden_bias_mean=-0.1,Wimpl='full',lam=lam)
elif (( 'timit' in dataset ) and flag_generator):
if (n_layers==1):
print "Dataset is timit and we are running a generator with 1 layer, so we'll only optimize V, b, W, and h_0, and use initialization for U and c."
print ""
if (Wimpl=='adhoc'):
# only optimize V, hidden_bias, W parameters, h_0
parameters_optimize=[parameters[0],parameters[2],parameters[3],parameters[5],parameters[6]]+parameters[7:]
elif (Wimpl=='full'):
# only optimize V, hidden_bias, W parameters, h_0
parameters_optimize=[parameters[0],parameters[2],parameters[4],parameters[5]]+parameters[6:]
idx_project=[3]
gradients = T.grad(costs[0], parameters_optimize)
else:
if (Wimpl=='full'):
idx_project=[5]
gradients = T.grad(costs[0], parameters)
elif (model == 'IRNN'):
inputs, parameters, costs = IRNN(n_input, n_hidden, n_output, input_type=input_type,
out_every_t=out_every_t, loss_function=loss_function)
gradients = T.grad(costs[0], parameters)
gradients = [T.clip(g, -gradient_clipping, gradient_clipping) for g in gradients]
elif (model == 'RNN'):
inputs, parameters, costs = tanhRNN(n_input, n_hidden, n_output, input_type=input_type,
out_every_t=out_every_t, loss_function=loss_function)
gradients = T.grad(costs[0], parameters)
gradients = [T.clip(g, -gradient_clipping, gradient_clipping) for g in gradients]
else:
print "Unsuported model:", model
return
# allocate shared theano variables to hold train and test data
s_train_z = theano.shared(train_z,borrow=True)
if (dataset=='synthgen'):
s_train_xdata = theano.shared(np.zeros((time_steps,1,1)).astype(np.float32))
else:
s_train_xdata = theano.shared(train_xdata,borrow=True)
s_test_z = theano.shared(test_z,borrow=True)
if (dataset=='synthgen'):
s_test_xdata = theano.shared(np.zeros((time_steps,1,1)).astype(np.float32))
else:
s_test_xdata = theano.shared(test_xdata,borrow=True)
if offset_eval is not None:
s_eval_z = theano.shared(eval_z,borrow=True)
s_eval_xdata = theano.shared(eval_xdata,borrow=True)
if (dataset=='synthgen'):
s_synth_Waug = theano.shared(synth_Waug)
if flag_use_mask:
s_train_xdata_mask = theano.shared(train_xdata_mask,borrow=True)
s_test_xdata_mask = theano.shared(test_xdata_mask,borrow=True)
if offset_eval is not None:
s_eval_xdata_mask = theano.shared(eval_xdata_mask,borrow=True)
# --- Compile theano functions --------------------------------------------------
index = T.iscalar('i')
if (dataset == 'synthgen') or ( ('timit' in dataset) and flag_generator):
updates, rmsprop = rms_prop(learning_rate, parameters_optimize, gradients,idx_project)
else:
idx_project=None #assume we aren't doing projected gradient on any parameters
if flag_useFullW:
idx_project=[5]
updates, rmsprop = rms_prop(learning_rate, parameters, gradients,idx_project)
if (optim_alg=='sgd'):
updates = gradient_descent(learning_rate, parameters, gradients)
rmsprop = []
if (dataset == 'synthgen'):
# run theano functions to generate train and test data from random inputs
Vaug = np.zeros((n_input,2*n_hidden),dtype=np.float32)
Vaug[:n_input,:n_hidden] = np.eye(n_hidden)
#Vaug = test_prng.randn(n_input,2*n_hidden).astype(np.float32)
V_synth = theano.shared(Vaug)
Uaug = np.zeros((2*n_hidden,n_output),dtype=np.float32)
Uaug[:n_hidden,:n_output] = np.eye(n_hidden)
U_synth = theano.shared(Uaug)
h_0_synth = theano.shared(np.zeros((1,2*n_hidden),dtype=np.float32))
givens_synth_train = {inputs_synth[0] : s_train_z,
inputs_synth[1] : s_train_xdata,
parameters_synth[0] : V_synth,
parameters_synth[1] : U_synth,
parameters_synth[4] : h_0_synth,
parameters_synth[5] : s_synth_Waug}
synth_train = theano.function([], costs_synth[2], givens=givens_synth_train)
givens_synth_test = {inputs_synth[0] : s_test_z,
inputs_synth[1] : s_test_xdata,
parameters_synth[0] : V_synth,
parameters_synth[1] : U_synth,
parameters_synth[4] : h_0_synth,
parameters_synth[5] : s_synth_Waug}
synth_test = theano.function([], costs_synth[2], givens=givens_synth_test)
if offset_eval is not None:
givens_synth_eval = {inputs_synth[0] : s_eval_z,
inputs_synth[1] : s_eval_xdata,
parameters_synth[0] : V_synth,
parameters_synth[1] : U_synth,
parameters_synth[4] : h_0_synth,
parameters_synth[5] : s_synth_Waug}
synth_eval = theano.function([], costs_synth[2], givens=givens_synth_eval)
# synthesize outputs for train and test
print "Generating outputs for train set"
print ""
train_y_synth = synth_train()
train_xdata = train_y_synth
s_train_xdata=theano.shared(train_y_synth,borrow=True)
print "Generating outputs for test set"
print ""
test_y_synth = synth_test()
test_xdata = test_y_synth
s_test_xdata=theano.shared(test_y_synth,borrow=True)
if offset_eval is not None:
print "Generating outputs for eval set"
print ""
eval_y_synth = synth_eval()
eval_xdata = eval_y_synth
s_eval_xdata=theano.shared(eval_y_synth,borrow=True)
# set up train and test functions for training
if flag_use_mask:
givens = {inputs[0] : s_train_z[:, n_batch * index : n_batch * (index + 1), :],
inputs[1] : s_train_xdata[:, n_batch * index : n_batch * (index + 1), :],
inputs[2] : s_train_xdata_mask[:, n_batch * index : n_batch * (index + 1), :]}
givens_test = {inputs[0] : s_test_z,
inputs[1] : s_test_xdata,
inputs[2] : s_test_xdata_mask}
if offset_eval is not None:
givens_eval = {inputs[0] : s_eval_z,
inputs[1] : s_eval_xdata,
inputs[2] : s_eval_xdata_mask}
else:
givens = {inputs[0] : s_train_z[:, n_batch * index : n_batch * (index + 1), :],
inputs[1] : s_train_xdata[:, n_batch * index : n_batch * (index + 1), :]}
givens_test = {inputs[0] : s_test_z,
inputs[1] : s_test_xdata}
if offset_eval is not None:
givens_eval = {inputs[0] : s_eval_z,
inputs[1] : s_eval_xdata}
# load parameters from the specified initfile, if it exists
if initfile is not None and os.path.isfile(initfile):
print "Using file %s to initialize parameters" % initfile
L=cPickle.load(file(initfile,'r'))
best_params_load=L['best_params']
V=theano.shared(best_params_load[0])
U=theano.shared(best_params_load[1])
hidden_bias=theano.shared(best_params_load[2])
if (model=='LSTM'):
for iparam in range(len(best_params_load)):
dupdate={parameters[iparam] : theano.shared(best_params_load[iparam])}
givens.update(dupdate)
givens_test.update(dupdate)
if offset_eval is not None:
givens_eval.update(dupdate)
elif (Wimpl=='adhoc'):
reflection=theano.shared(best_params_load[3])
out_bias=theano.shared(best_params_load[4])
theta=theano.shared(best_params_load[5])
h_0=theano.shared(best_params_load[6])
if ( ('timit' in dataset) and flag_generator):
if (n_layers==1):
print "Dataset is timit and we are running a generator with 1 layer, so initialize U and c from initfile."
print ""
# only use U and out_bias for initialization
givens.update({parameters[1] : U,
parameters[4] : out_bias})
givens_test.update({parameters[1] : U,
parameters[4] : out_bias})
if offset_eval is not None:
givens_eval.update({parameters[1] : U,
parameters[4] : out_bias})
else:
givens_test.update({parameters[0] : V,
parameters[1] : U,
parameters[2] : hidden_bias,
parameters[3] : reflection,
parameters[4] : out_bias,
parameters[5] : theta,
parameters[6] : h_0})
if offset_eval is not None:
givens_eval.update({parameters[0] : V,
parameters[1] : U,
parameters[2] : hidden_bias,
parameters[3] : reflection,
parameters[4] : out_bias,
parameters[5] : theta,
parameters[6] : h_0})
elif (Wimpl=='full'):
out_bias=theano.shared(best_params_load[3])
h_0=theano.shared(best_params_load[4])
Waug=theano.shared(best_params_load[5])
if (('timit' in dataset) and flag_generator):
if (n_layers==1):
print "Dataset is timit and we are running a generator with 1 layer, so initialize U and c from initfile."
print ""
# only use U and out_bias for initialization
givens.update({parameters[1] : U,
parameters[3] : out_bias})
givens_test.update({parameters[1] : U,
parameters[3] : out_bias})
if offset_eval is not None:
givens_eval.update({parameters[1] : U,
parameters[3] : out_bias})
else:
givens_test.update({parameters[0] : V,
parameters[1] : U,
parameters[2] : hidden_bias,
parameters[3] : out_bias,
parameters[4] : h_0,
parameters[5] : Waug})
if offset_eval is not None:
givens_eval.update({parameters[0] : V,
parameters[1] : U,
parameters[2] : hidden_bias,
parameters[3] : out_bias,
parameters[4] : h_0,
parameters[5] : Waug})
if (dataset == 'synthgen') and flag_onlyOptimW:
# we are only optimizing W, so use some ground-truth parameters from
# the synth networks
givens[parameters[0]] = Vaug
givens[parameters[1]] = Uaug
givens[parameters[2]] = theano.shared(parameters_synth[2].get_value()) # hidden_bias
# out_bias
if (Wimpl == 'adhoc'):
givens[parameters[4]] = theano.shared(np.zeros((2*n_output,), dtype=theano.config.floatX))
elif (Wimpl == 'givens') or (Wimpl == 'full'):
givens[parameters[3]] = theano.shared(np.zeros((2*n_output,), dtype=theano.config.floatX))
# h_0
if (Wimpl == 'adhoc'):
givens[parameters[6]] = h_0_synth
elif (Wimpl == 'givens') or (Wimpl == 'full'):
givens[parameters[4]] = h_0_synth
givens_test[parameters[0]] = Vaug
givens_test[parameters[1]] = Uaug
givens_test[parameters[2]] = theano.shared(parameters_synth[2].get_value()) # hidden_bias
# out_bias
if (Wimpl == 'adhoc'):
givens_test[parameters[4]] = theano.shared(np.zeros((2*n_output,), dtype=theano.config.floatX))
elif (Wimpl == 'givens') or (Wimpl == 'full'):
givens_test[parameters[3]] = theano.shared(np.zeros((2*n_output,), dtype=theano.config.floatX))
# h_0
if (Wimpl == 'adhoc'):
givens_test[parameters[6]] = h_0_synth
elif (Wimpl == 'givens') or (Wimpl == 'full'):
givens_test[parameters[4]] = h_0_synth
if offset_eval is not None:
givens_eval[parameters[0]] = Vaug
givens_eval[parameters[1]] = Uaug
givens_eval[parameters[2]] = theano.shared(parameters_synth[2].get_value()) # hidden_bias
# out_bias
if (Wimpl == 'adhoc'):
givens_eval[parameters[4]] = theano.shared(np.zeros((2*n_output,), dtype=theano.config.floatX))
elif (Wimpl == 'givens') or (Wimpl == 'full'):
givens_eval[parameters[3]] = theano.shared(np.zeros((2*n_output,), dtype=theano.config.floatX))
# h_0
if (Wimpl == 'adhoc'):
givens_eval[parameters[6]] = h_0_synth
elif (Wimpl == 'givens') or (Wimpl == 'full'):
givens_eval[parameters[4]] = h_0_synth
train = theano.function([index], [costs[0],costs[1]], givens=givens, updates=updates)
test = theano.function([], [costs[0], costs[1], costs[2], costs[3], costs[4], costs[5]], givens=givens_test)
if offset_eval is not None:
evalf = theano.function([], [costs[0], costs[1], costs[2], costs[3], costs[4], costs[5]], givens=givens_eval)
# --- Training Loop ---------------------------------------------------------------
train_loss = []
train_ref = []
if (loss_function=='MSEplusL1'):
train_mse = []
test_mse = []
train_time= []
test_loss = []
test_ref = []
test_time = []
best_params = [p.get_value() for p in parameters]
best_rms = [r.get_value() for r in rmsprop]
best_test_loss = 1e10
#num_allowed_test_inc=10
num_test_inc=0
shuffle_rng=np.random.RandomState(314)
data_xdata = s_train_xdata.get_value()
for i in xrange(n_iter):
if (i % num_batches == 0):
# reshuffle batch indices
inds = shuffle_rng.permutation(n_train)
data_z = s_train_z.get_value()
s_train_z.set_value(data_z[:,inds,:])
data_xdata = s_train_xdata.get_value()
s_train_xdata.set_value(data_xdata[:,inds,:])
if flag_use_mask:
data_xdata_mask = s_train_xdata_mask.get_value()
s_train_xdata_mask.set_value(data_xdata_mask[:,inds,:])
start_time=time.time()
mse, extra = train(i % num_batches)
elapsed_time = time.time() - start_time
train_loss.append(mse)
msp = (data_xdata[:, n_batch * (i%num_batches):n_batch * (i%num_batches+1),:]**2).mean() #mean-squared power of reference
train_ref.append(msp)
train_time.append(elapsed_time)
print "Iteration:", i
if (loss_function=='MSEplusL1'):
train_mse.append(extra)
print "MSE + L1: ", mse
print "MSE : ", extra
print "NMSE : ", extra/msp
else:
print "MSE: ", mse
print "NMSE:", mse/msp
print "Time:", elapsed_time
print
if (i % iters_per_validCheck==0):
start_time=time.time()
mse, extra, xgen, ht, nmse_local, cost_steps = test()
elapsed_time = time.time() - start_time
msp = (test_xdata**2).mean()
print
print "TEST"
if (loss_function=='MSEplusL1'):
test_mse.append(extra)
print "MSE + L1: ", mse
print "MSE : ", extra
print "NMSE : ", extra/msp
else:
print "MSE: ", mse
print "NMSE global:", mse/msp
print "NMSE local:", nmse_local.mean()
print "Time:", elapsed_time
print
test_loss.append(mse)
test_ref.append(msp)
test_time.append(elapsed_time)
if mse < best_test_loss:
best_params = [p.get_value() for p in parameters]
best_rms = [r.get_value() for r in rmsprop]
best_test_loss = mse
best_xgen = xgen
best_ht = ht
best_nmse_local = nmse_local
else:
num_test_inc=num_test_inc+1
print "No improvement in test loss, %d of %d allowed" % (num_test_inc,num_allowed_test_inc)
print ""
if num_test_inc==num_allowed_test_inc:
print "Number of allowed test loss increments reached. Returning..."
print ""
return
save_vals = {'parameters': [p.get_value() for p in parameters],
'rmsprop': [r.get_value() for r in rmsprop],
'train_loss': train_loss,
'train_ref': train_ref,
'train_time': train_time,
'test_loss': test_loss,
'test_ref': test_ref,
'test_time': test_time,
'best_params': best_params,
'best_rms': best_rms,
'best_test_loss': best_test_loss,
'best_xgen': best_xgen,
#'best_ht': best_ht,
'best_nmse_local': best_nmse_local,
'model': model,
'stats': stats}
if (loss_function=='MSEplusL1'):
save_vals['train_mse']=train_mse
save_vals['test_mse']=test_mse
cPickle.dump(save_vals,
file(savefile, 'wb'),
cPickle.HIGHEST_PROTOCOL)
# run evaluation data
if offset_eval is not None:
print ""
print "Running forward model on evaluation data using best validation parameters..."
start_time=time.time()
mse, extra, xgen, ht, nmse_local, cost_steps = evalf()
elapsed_time = time.time() - start_time
print "Forward pass took %f seconds" % elapsed_time
msp = (eval_xdata**2).mean()
print "MSE=%f, ref=%f, NMSE=%f" % (mse,msp,mse/msp)
save_vals = {'parameters': [p.get_value() for p in parameters],
'eval_loss': mse,
'eval_ref': msp,
'eval_time': elapsed_time,
'best_params': best_params,
'xgen': xgen,
'ht': ht,
'model': model,
'stats': stats}
cPickle.dump(save_vals,
file(savefile+'_eval', 'wb'),
cPickle.HIGHEST_PROTOCOL)
if outdir is not None:
outdir=''.join(["/data1/prediction_audio_out/", outdir])
if not os.path.exists(outdir):
os.makedirs(outdir)
if not outdir[-1]=='/':
outdir=outdir + '/' #append a slash on the end
#write out synthetized audio
mse, extra, xgen, ht, nmse_local, cost_steps = test()
cPickle.dump({'ht': ht},
file('ht_cur', 'wb'),
cPickle.HIGHEST_PROTOCOL)
Yest=xgen
# Yest is shape nframMax x nutt x 2*nch*F
nutt=Yest.shape[1]
nsrc=1
for iutt in range(nutt):
#perform inverse STFT for estimates Yest
Yest_cur = np.transpose(np.squeeze(Yest[:,iutt,:]),(1,0))
# Yest_cur is 2*nch*F x nframMax
yestr = iAugSTFT(Yest_cur,F,1,flag_unwrap_phase,flag_noDiv=flag_noDiv,window=params_stft['window'],hop=params_stft['hop'])
# yestr is 1 x nsampl x nch
#perform inverse STFT for references test_y
test_xdata_cur = np.transpose(np.squeeze(test_xdata[:,iutt,:]),(1,0))
# test_y_cur is 2*nch*F x nframMax
yr = iAugSTFT(test_xdata_cur,F,1,flag_unwrap_phase,flag_noDiv=flag_noDiv,window=params_stft['window'],hop=params_stft['hop'])
# yr is 1 x nsampl x nch
#build parts of output wavfile
wavfile_in_cur=wavfiles_test[iutt]
wavfile_in_split=wavfile_in_cur.split('/')
filename=wavfile_in_split[-2] + '_' + wavfile_in_split[-1]
print 'Writing out wav for file %d of %d total...' % (iutt+1,nutt)
#write out audio of data example (to check the istft plumbing)
filename_out=filename
path_out=''.join((outdir,filename_out))
wavwrite(path_out,8e3,np.squeeze(yr[0,:,:]))
#write out audio file of reconstructed output
path_out=''.join((outdir,filename_out.replace('.wav','') + '_gen%d' % iutt, '.wav'))
wavwrite(path_out,8e3,np.squeeze(yestr[0,:,:]))
def build_arg_parser():
    """Build the command-line parser for this training script.

    Returns an ``argparse.ArgumentParser`` with eleven required positional
    arguments followed by the optional ``--`` settings.  Optional settings
    are declared without ``type=``; command-line values therefore arrive as
    strings and are converted later by ``args_to_kwargs``.
    """
    parser = argparse.ArgumentParser(
        description="training a model")

    # Positionals are declared without nargs='?', so argparse always requires
    # them and the default= values below are never actually used.
    parser.add_argument("n_iter", type=int, default=20000)
    parser.add_argument("n_batch", type=int, default=20)
    parser.add_argument("n_hidden", type=int, default=512)
    parser.add_argument("learning_rate", type=float, default=0.001)
    parser.add_argument("savefile")
    parser.add_argument("model", default='complex_RNN')
    parser.add_argument("input_type", default='categorical')
    parser.add_argument("out_every_t", default='False')
    parser.add_argument("loss_function", default='MSE')
    parser.add_argument("fold", default='fold1')
    parser.add_argument("scene")

    # Optional settings.  The flag_* options are parsed with int(), so they
    # must be given as integers (e.g. 0 or 1) on the command line.
    parser.add_argument("--n_reflections", default=8,
                        help="number of reflections for CUE-RNN")
    parser.add_argument("--flag_telescope", default=True,
                        help="whether to use telescoping reflections (True) or full reflections (False)")
    parser.add_argument("--nch", default=1,
                        help="how many channels of audio input to use (default=1)")
    parser.add_argument("--flag_unwrap_phase", default=True,
                        help="remove window hop phase on target STFTs (default=True)")
    parser.add_argument("--indir", default="audio_8khz",
                        help="input directory for DCASE2016 audio file clips (default=audio_8khz)")
    parser.add_argument("--outdir", default=None,
                        help="output directory for reconstructed files (default=None)")
    parser.add_argument("--dataset", default="timit",
                        help="dataset to use (default=timit)")
    parser.add_argument("--initfile", default=None,
                        help="savefile to initialize generator network with (default=None)")
    parser.add_argument("--flag_feed_forward", default=True,
                        help="disable recurrent connections (default=True)")
    parser.add_argument("--flag_generator", default=False,
                        help="run as a generator i.e., inputs are random unit-variance circular complex Gaussian vectors (default=False)")
    parser.add_argument("--downsample_train", default=1,
                        help="downsampling factor for training data (default=1)")
    parser.add_argument("--downsample_test", default=1,
                        help="downsampling factor for test data (default=1)")
    parser.add_argument("--time_steps", default=0,
                        help="number of time steps for sequences when using dataset='synthgen' (default=0, not used if dataset does not equal 'synthgen')")
    parser.add_argument("--n_Givens", default=0,
                        help="number of Givens rotations to use for uRNN (default=0; if set to None or 0, ad-hoc parameterization of [Arjovsky, Shah, and Bengio 2015] is used)")
    parser.add_argument("--prng_seed_Givens", default=52016,
                        help="seed to initialize random number generator that determines which indices will be used in Givens parameterization (default=52016)")
    parser.add_argument("--num_allowed_test_inc", default=10,
                        help="number of allowed increases in the test error before stopping training (default=10)")
    parser.add_argument("--iters_per_validCheck", default=20,
                        help="number of training iterations between validation checks on test set (default=20)")
    parser.add_argument("--flag_useFullW", default=False,
                        help="use a full unitary matrix for W, using Stiefel manifold projected gradient for optimization; overrides n_Givens (default=False)")
    parser.add_argument("--flag_onlyOptimW", default=True,
                        help="if dataset is synthgen, only optimize W, and not the other parameters V, U, hidden_bias, out_bias (default=True)")
    parser.add_argument("--lam", default=0.0,
                        help="used as L1 regularization weight if loss_function is MSEplusL1 (default=0.0)")
    parser.add_argument("--Vnorm", default=0.0,
                        help="normalize rows of V to equal this number (default=0.0, which means no normalization is performed)")
    parser.add_argument("--Unorm", default=0.0,
                        help="normalize columnss of U to equal this number (default=0.0, which means no normalization is performed)")
    parser.add_argument("--n_layers", default=1,
                        help="number of RNN layers to use (default=1)")
    parser.add_argument("--num_pred_steps", default=0,
                        help="number of steps to predict ahead (default=0)")
    parser.add_argument("--hidden_bias_mean", default=0.1,
                        help="mean of initial hidden bias in all layers (default=0.1)")
    parser.add_argument("--data_transform", default="",
                        help="apply a transformation to the input and output data (default='', options: 'logmag')")
    parser.add_argument("--bwe_frac", default=1.0,
                        help="if less than 1.0, sets up data to do bandwidth expansion by using the lower (1-bwe_frac) portion of the data to predict the upper bwe_frac portion of the data (default=1.0)")
    parser.add_argument("--data_normalization", default="none",
                        help="compute statistics and normalize training and test data per feature dimension (default='none', options: any combination of lower-case 'mean', 'var', and/or 'perUtt'. Defaults to global statistics (i.e., computed over all time steps and sequences))")
    # NOTE(review): the help text below says default=None, but the value
    # actually passed on is int(-1); confirm how main() treats -1.
    parser.add_argument("--offset_eval", default=-1,
                        help="starting index of eval set, uses downsample_test (default=None)")
    parser.add_argument("--olap", default=50,
                        help="overlap of STFT window, as number between 1 and 99. Only achieves perfect reconstruction for certain values, dependent on STFT window choice (e.g., 'sqrt_hann' with olap=50 gives PR) (default=50)")
    parser.add_argument("--window", default='hann',
                        help="STFT window, provided as a string. Options: 'hann', 'sqrt_hann' (default='hann')")
    parser.add_argument("--flag_noDiv", default=0,
                        help="if set, does not divide reconstructed time series by overlap-added squared window (default=0)")
    parser.add_argument("--flag_noComplexConstraint", default=0,
                        help="if set, relaxes complex constraint on input transform V and output transform U for complex_RNNs (default=0)")
    parser.add_argument("--Winit", default="svd",
                        help="Initialization method of synth_W in synthgen experiments. Options: 'svd', 'adhoc', 'adhoc2x' (default='svd')")
    parser.add_argument("--seed", default=1234,
                        help="random seed for LSTM and complex_RNN (default=1234)")
    parser.add_argument("--optim_alg", default="rmsprop",
                        help="optimization algorithm (default='rmsprop')")
    parser.add_argument("--n_utt_eval_spec", default=-1,
                        help="number of evaluation utterances, from offset_eval (default=-1)")
    return parser


def args_to_kwargs(args):
    """Convert a parsed argparse namespace into keyword arguments for main().

    Values are converted by group:
      * pass-through: used exactly as parsed;
      * integer:      int(value);
      * flag:         bool(int(value)), so '0'/'1' strings and bool defaults
                      both work, while a string such as 'True' raises
                      ValueError;
      * float32:      np.float32(value);
      * out_every_t:  True only when the parsed string is exactly 'True'.

    Returns a dict with one entry per command-line argument.
    """
    # `opts` rather than `dict`, so the builtin is not shadowed.
    opts = vars(args)

    passthrough_names = (
        'n_iter', 'n_batch', 'n_hidden', 'savefile', 'model', 'input_type',
        'loss_function', 'fold', 'scene', 'indir', 'outdir', 'dataset',
        'initfile', 'data_transform', 'data_normalization', 'window',
        'Winit', 'optim_alg',
    )
    int_names = (
        'n_reflections', 'nch', 'downsample_train', 'downsample_test',
        'time_steps', 'n_Givens', 'prng_seed_Givens', 'num_allowed_test_inc',
        'iters_per_validCheck', 'n_layers', 'num_pred_steps', 'offset_eval',
        'seed', 'n_utt_eval_spec',
    )
    flag_names = (
        'flag_telescope', 'flag_unwrap_phase', 'flag_feed_forward',
        'flag_generator', 'flag_useFullW', 'flag_onlyOptimW', 'flag_noDiv',
        'flag_noComplexConstraint',
    )
    float32_names = (
        'learning_rate', 'lam', 'Vnorm', 'Unorm', 'hidden_bias_mean',
        'bwe_frac', 'olap',
    )

    kwargs = {}
    for name in passthrough_names:
        kwargs[name] = opts[name]
    for name in int_names:
        kwargs[name] = int(opts[name])
    for name in flag_names:
        # Builtin int() replaces the np.int alias, which newer NumPy removed.
        kwargs[name] = bool(int(opts[name]))
    for name in float32_names:
        kwargs[name] = np.float32(opts[name])
    kwargs['out_every_t'] = ('True' == opts['out_every_t'])
    return kwargs


if __name__ == "__main__":
    main(**args_to_kwargs(build_arg_parser().parse_args()))