https://github.com/kenkov/seq2seq
Tip revision: 446045212ac8c12535cd1f2113e9ba37a5306207 authored by Noriyuki Abe on 26 February 2016, 21:45:45 UTC
Merge branch 'dev'
Merge branch 'dev'
Tip revision: 4460452
train.py
#! /usr/bin/env python
# coding:utf-8
if __name__ == "__main__":
import argparse
from seq2seq import train_encoder_decoder
from util import load_dictionary, load_sentence
import configparser
import os
from chainer import serializers
from gensim.models import word2vec
import logging
import relu_rnn
logging.basicConfig(
format='%(asctime)s : %(levelname)s : %(message)s',
level=logging.INFO
)
# GPU config
parser = argparse.ArgumentParser()
parser.add_argument('config_file', metavar='config_file', type=str,
help='config file')
parser.add_argument('--gpu', '-g', default=-1, type=int,
help='GPU ID (negative value indicates CPU)')
parser.add_argument('--type', '-t', default="relu", type=str,
help='GPU ID (negative value indicates CPU)')
args = parser.parse_args()
gpu_flag = args.gpu if args.gpu >= 0 else -1
config_file = args.config_file
parser_config = configparser.ConfigParser()
parser_config.read(config_file)
# params
config = parser_config["CONFIG"]
model_dir = config.get("model_dir")
dict_file = config.get("dict_file")
sent_file = config.get("sent_file")
conv_file = config.get("conv_file")
word2vec_init = config.getboolean("word2vec_init")
word2vec_model_file = config.get("word2vec_model_file")
min_freq = config.getint("min_freq")
n_units = config.getint("n_units")
epoch_size = config.getint("epoch_size")
batch_size = config.getint("batch_size")
dropout = config.getboolean("dropout")
print("### options ###")
print("model_dir: {}".format(model_dir))
print("sent_file: {}".format(sent_file))
print("conv_file: {}".format(conv_file))
print("dict_file: {}".format(dict_file))
print("word2vec_init: {}".format(word2vec_init))
print("word2vec_model_file: {}".format(word2vec_model_file))
print("min_freq: {}".format(min_freq))
print("n_units: {}".format(n_units))
print("epoch_size: {}".format(epoch_size))
print("batch_size: {}".format(batch_size))
print("dropout: {}".format(dropout))
print("##############")
# 辞書
if os.path.exists(dict_file):
dictionary = load_dictionary(dict_file)
else:
from util import create_dictionary
dictionary = create_dictionary(
[sent_file],
min_freq=min_freq
)
dictionary.save(dict_file)
# Prepare encoder RNN model
dim = len(dictionary.keys())
model_type = args.type
if model_type == "relu":
model = relu_rnn.Classifier(
relu_rnn.ReLURNN(
embed_dim=dim,
n_units=n_units,
gpu=args.gpu
)
)
elif model_type == "lstm":
import lstm
model = lstm.Classifier(
lstm.LSTM(
embed_dim=dim,
n_units=n_units,
gpu=args.gpu
)
)
else:
raise Exception("model argment should be relu or lstm")
# load model
init_model_name = os.path.join(
model_dir,
"model.npz"
)
if os.path.exists(init_model_name):
serializers.load_npz(init_model_name, model)
print("load model {}".format(init_model_name))
elif word2vec_init:
# initialize embedding layer by word2vec
import numpy as np
if os.path.exists(word2vec_model_file):
print("load word2vec model")
word2vec_model = word2vec.Word2Vec.load(word2vec_model_file)
else:
print("start learning word2vec model")
word2vec_model = word2vec.Word2Vec(
load_sentence(sent_file),
size=n_units,
window=5,
min_count=1,
workers=4
)
print("save word2vec model")
word2vec_model.save(word2vec_model_file)
# initialize word embedding layer with word2vec
initial_W = np.array([
word2vec_model[dictionary[wid]]
if dictionary[wid] in word2vec_model
else np.array(
[np.random.random() for _ in range(n_units)],
dtype=np.float32
)
for wid in range(dim)],
dtype=np.float32
)
not_found_words = []
for wid in range(dim):
if dictionary[wid] not in word2vec_model:
not_found_words.append(dictionary[wid])
print("{} are not found in word2vec model".format(not_found_words))
model.predictor.set_word_embedding(initial_W)
# print(initial_W)
print("finish initializing word embedding with word2vec")
train_encoder_decoder(
model,
dictionary,
conv_file,
model_dir,
epoch_size,
batch_size,
dropout,
gpu=gpu_flag
)