# Source: https://github.com/zhouyiwei/cc
# Tip revision: b0efa7ecacd6747f20711b78cc5d9c2d7ec4adbc authored by zhouyiwei on 25 June 2018, 18:16:30 UTC
# (repository metadata: "Create README.md", tip revision b0efa7e)
# File: train.py
import tensorflow as tf
from utils import *
from sklearn.model_selection import KFold
from models import *
import time
import datetime
# Command-line flags (TensorFlow 1.x tf.app.flags API).
# --- Data locations ---
tf.app.flags.DEFINE_string("dir", "/data", "folder directory")
tf.app.flags.DEFINE_string("training_file", "clickbait17-validation-170630", "Training data file")
tf.app.flags.DEFINE_string("validation_file", "clickbait17-train-170331", "Validation data file")
# --- Training schedule ---
tf.app.flags.DEFINE_integer("epochs", 20, "epochs")
tf.app.flags.DEFINE_integer("batch_size", 32, "batch_size")
# --- CNN hyperparameters (used by CNN-style models) ---
tf.app.flags.DEFINE_string("filter_sizes", "3,4,5", "Comma-separated filter sizes")
tf.app.flags.DEFINE_integer("num_filters", 100, "Number of filters per filter size")
# --- Regularization ---
tf.app.flags.DEFINE_float("dropout_rate_hidden", 0.5, "Dropout rate of hidden layer")
tf.app.flags.DEFINE_float("dropout_rate_cell", 0.3, "Dropout rate of rnn cell")
tf.app.flags.DEFINE_float("dropout_rate_embedding", 0.2, "Dropout rate of word embedding")
# --- Model architecture ---
tf.app.flags.DEFINE_integer("state_size", 64, "state_size")
tf.app.flags.DEFINE_integer("hidden_size", 0, "hidden_size")
tf.app.flags.DEFINE_string("timestamp", "0715", "Timestamp")
tf.app.flags.DEFINE_integer("y_len", 4, "how to interpret the annotation")
tf.app.flags.DEFINE_string("model", "SAN", "which model to use")
# --- Optional extra inputs ---
tf.app.flags.DEFINE_boolean("use_target_description", False, "whether to use the target description as input")
tf.app.flags.DEFINE_boolean("use_image", False, "whether to use the image as input")
# --- Optimization ---
tf.app.flags.DEFINE_float("learning_rate", 0.005, "learning rate")
tf.app.flags.DEFINE_integer("embedding_size", 100, "embedding size")
tf.app.flags.DEFINE_float("gradient_clipping_value", 2, "gradient clipping value")
FLAGS = tf.app.flags.FLAGS
def main(argv=None):
    """Train a clickbait-strength model with 5-fold cross-validation.

    Loads GloVe embeddings and the clickbait17 data (post texts, optional
    target descriptions and image features), shuffles and pads the sequences,
    then for each of 5 folds builds a fresh graph/session, trains the model
    selected by FLAGS.model for FLAGS.epochs epochs, and tracks the best
    validation MSE per fold. Prints the mean best MSE and the corresponding
    mean accuracy across folds.

    Args:
        argv: Unused; present for the tf.app.run() entry-point contract.
    """
    np.random.seed(81)  # fixed seed so the shuffle (and thus the folds) is reproducible
    word2id, embedding = load_embeddings(fp=os.path.join(FLAGS.dir, "glove.6B." + str(FLAGS.embedding_size) + "d.txt"), embedding_size=FLAGS.embedding_size)
    # Persist the vocabulary mapping so inference can reuse the same ids.
    with open(os.path.join(FLAGS.dir, 'word2id.json'), 'w') as fout:
        json.dump(word2id, fp=fout)
    ids, post_texts, truth_classes, post_text_lens, truth_means, target_descriptions, target_description_lens, image_features = read_data(word2id=word2id, fps=[os.path.join(FLAGS.dir, FLAGS.training_file), os.path.join(FLAGS.dir, FLAGS.validation_file)], y_len=FLAGS.y_len, use_target_description=FLAGS.use_target_description, use_image=FLAGS.use_image)
    post_texts = np.array(post_texts)
    truth_classes = np.array(truth_classes)
    post_text_lens = np.array(post_text_lens)
    truth_means = np.array(truth_means)
    # Shuffle all parallel arrays with one permutation so rows stay aligned.
    shuffle_indices = np.random.permutation(np.arange(len(post_texts)))
    post_texts = post_texts[shuffle_indices]
    truth_classes = truth_classes[shuffle_indices]
    post_text_lens = post_text_lens[shuffle_indices]
    truth_means = truth_means[shuffle_indices]
    max_post_text_len = max(post_text_lens)
    print(max_post_text_len)
    post_texts = pad_sequences(post_texts, max_post_text_len)
    target_descriptions = np.array(target_descriptions)
    target_description_lens = np.array(target_description_lens)
    target_descriptions = target_descriptions[shuffle_indices]
    target_description_lens = target_description_lens[shuffle_indices]
    max_target_description_len = max(target_description_lens)
    print(max_target_description_len)
    target_descriptions = pad_sequences(target_descriptions, max_target_description_len)
    image_features = np.array(image_features)
    # Zip the parallel arrays into one record array so KFold indexing keeps rows together.
    data = np.array(list(zip(post_texts, truth_classes, post_text_lens, truth_means, target_descriptions, target_description_lens, image_features)))
    kf = KFold(n_splits=5)
    fold = 1  # renamed from `round`, which shadowed the builtin
    val_scores = []  # best validation MSE per fold
    val_accs = []    # accuracy at the best-MSE epoch per fold
    for train, validation in kf.split(data):
        train_data, validation_data = data[train], data[validation]
        g = tf.Graph()  # fresh graph per fold so variables don't leak across folds
        with g.as_default() as g:
            tf.set_random_seed(81)
            with tf.Session(graph=g) as sess:
                # Keyword arguments shared by every model variant.
                common_kwargs = dict(
                    x1_maxlen=max_post_text_len,
                    y_len=len(truth_classes[0]),
                    x2_maxlen=max_target_description_len,
                    embedding=embedding,
                    filter_sizes=list(map(int, FLAGS.filter_sizes.split(","))),
                    num_filters=FLAGS.num_filters,
                    hidden_size=FLAGS.hidden_size,
                    state_size=FLAGS.state_size,
                    x3_size=len(image_features[0]),
                )
                if FLAGS.model == "SAN":
                    # SAN additionally takes an attention size (bi-RNN state is 2*state_size).
                    model = SAN(attention_size=2 * FLAGS.state_size, **common_kwargs)
                elif FLAGS.model == "DAN":
                    model = DAN(**common_kwargs)
                elif FLAGS.model == "CNN":
                    model = CNN(**common_kwargs)
                elif FLAGS.model == "BiRNN":
                    model = BiRNN(**common_kwargs)
                else:
                    # Previously an unknown flag value crashed later with NameError.
                    raise ValueError("Unknown model: {}".format(FLAGS.model))
                global_step = tf.Variable(0, name="global_step", trainable=False)
                optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)
                grads_and_vars = optimizer.compute_gradients(model.loss)
                if FLAGS.gradient_clipping_value:
                    # Skip None gradients (variables unused by the loss); clipping None crashes.
                    grads_and_vars = [(tf.clip_by_value(grad, -FLAGS.gradient_clipping_value, FLAGS.gradient_clipping_value), var)
                                      for grad, var in grads_and_vars if grad is not None]
                train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)
                out_dir = os.path.join(FLAGS.dir, "runs", FLAGS.timestamp)
                checkpoint_dir = os.path.join(out_dir, "checkpoints")
                checkpoint_prefix = os.path.join(checkpoint_dir, FLAGS.model + str(fold))
                if not os.path.exists(checkpoint_dir):
                    os.makedirs(checkpoint_dir)
                # Saver kept so checkpointing (currently disabled below) can be re-enabled.
                saver = tf.train.Saver()
                sess.run(tf.global_variables_initializer())

                def train_step(input_x1, input_y, input_x1_len, input_z, input_x2, input_x2_len, input_x3):
                    """Run one optimizer step on a batch and print progress."""
                    feed_dict = {model.input_x1: input_x1,
                                 model.input_y: input_y,
                                 model.input_x1_len: input_x1_len,
                                 model.input_z: input_z,
                                 model.dropout_rate_hidden: FLAGS.dropout_rate_hidden,
                                 model.dropout_rate_cell: FLAGS.dropout_rate_cell,
                                 model.dropout_rate_embedding: FLAGS.dropout_rate_embedding,
                                 model.batch_size: len(input_x1),
                                 model.input_x2: input_x2,
                                 model.input_x2_len: input_x2_len,
                                 model.input_x3: input_x3}
                    _, step, loss, mse, accuracy = sess.run([train_op, global_step, model.loss, model.mse, model.accuracy], feed_dict)
                    time_str = datetime.datetime.now().isoformat()
                    print("{}: step {}, loss {:g}, mse {:g}, acc {:g}".format(time_str, step, loss, mse, accuracy))

                def validation_step(input_x1, input_y, input_x1_len, input_z, input_x2, input_x2_len, input_x3, writer=None):
                    """Evaluate on the validation split (dropout disabled); return (mse, accuracy)."""
                    feed_dict = {model.input_x1: input_x1,
                                 model.input_y: input_y,
                                 model.input_x1_len: input_x1_len,
                                 model.input_z: input_z,
                                 model.dropout_rate_hidden: 0,
                                 model.dropout_rate_cell: 0,
                                 model.dropout_rate_embedding: 0,
                                 model.batch_size: len(input_x1),
                                 model.input_x2: input_x2,
                                 model.input_x2_len: input_x2_len,
                                 model.input_x3: input_x3}
                    step, loss, mse, accuracy = sess.run([global_step, model.loss, model.mse, model.accuracy], feed_dict)
                    time_str = datetime.datetime.now().isoformat()
                    print("{}: step {}, loss {:g}, mse {:g}, acc {:g}".format(time_str, step, loss, mse, accuracy))
                    return mse, accuracy

                # Baseline validation pass before any training.
                print("\nValidation: ")
                post_text_val, truth_class_val, post_text_len_val, truth_mean_val, target_description_val, target_description_len_val, image_feature_val = zip(*validation_data)
                validation_step(post_text_val, truth_class_val, post_text_len_val, truth_mean_val, target_description_val, target_description_len_val, image_feature_val)
                print("\n")
                min_mse_val = np.inf
                acc = np.inf  # accuracy recorded at the epoch with the lowest validation MSE
                for i in range(FLAGS.epochs):
                    batches = get_batch(train_data, FLAGS.batch_size)
                    for batch in batches:
                        post_text_batch, truth_class_batch, post_text_len_batch, truth_mean_batch, target_description_batch, target_description_len_batch, image_feature_batch = zip(*batch)
                        train_step(post_text_batch, truth_class_batch, post_text_len_batch, truth_mean_batch, target_description_batch, target_description_len_batch, image_feature_batch)
                    print("\nValidation: ")
                    mse_val, acc_val = validation_step(post_text_val, truth_class_val, post_text_len_val, truth_mean_val, target_description_val, target_description_len_val, image_feature_val)
                    print("\n")
                    if mse_val < min_mse_val:
                        min_mse_val = mse_val
                        acc = acc_val
                        # saver.save(sess, checkpoint_prefix)
        fold += 1
        val_scores.append(min_mse_val)
        val_accs.append(acc)
    # Cross-validation summary: mean best MSE and mean accuracy over the 5 folds.
    print(np.mean(val_scores))
    print(np.mean(val_accs))
# tf.app.run() parses the flags above, then calls main(argv) in this module.
if __name__ == "__main__":
    tf.app.run()