https://github.com/kenkov/seq2seq
Tip revision: 446045212ac8c12535cd1f2113e9ba37a5306207 authored by Noriyuki Abe on 26 February 2016, 21:45:45 UTC
Merge branch 'dev'
test.py
#! /usr/bin/env python
# coding:utf-8
if __name__ == "__main__":
    import sys
    import argparse
    import configparser
    import os

    from chainer import serializers

    from seq2seq import decode
    from util import load_dictionary
    # command line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('config_file', metavar='config_file', type=str,
                        help='config file')
    parser.add_argument('--gpu', '-g', default=-1, type=int,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--type', '-t', default="relu", type=str,
                        help='model type: "relu" or "lstm"')
    args = parser.parse_args()
    gpu_flag = args.gpu >= 0
    config_file = args.config_file
    parser_config = configparser.ConfigParser()
    parser_config.read(config_file)
    config = parser_config["CONFIG"]
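    # A minimal example of the expected config file (values are illustrative
    # assumptions; the script only requires a [CONFIG] section providing
    # model_dir, n_units, and dict_file, which are read below):
    #
    #   [CONFIG]
    #   model_dir = ./model
    #   n_units = 200
    #   dict_file = ./dictionary.txt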
# config["SEPARATOR"] = bytes(
# config["DEFAULT"]["SEPARATOR"], "utf-8"
# ).decode("unicode_escape")
    # params
    model_dir = config["model_dir"]
    n_units = int(config["n_units"])

    # load the word dictionary used for encoding/decoding
    dictionary = load_dictionary(config["dict_file"])

    # Prepare encoder RNN model
    dim = len(dictionary)  # vocabulary size
    model_type = args.type
    if model_type == "relu":
        import relu_rnn
        model = relu_rnn.Classifier(
            relu_rnn.ReLURNN(
                embed_dim=dim,
                n_units=n_units,
                gpu=args.gpu
            )
        )
    elif model_type == "lstm":
        import lstm
        model = lstm.Classifier(
            lstm.LSTM(
                embed_dim=dim,
                n_units=n_units,
                gpu=args.gpu
            )
        )
    else:
        raise ValueError("model type argument should be 'relu' or 'lstm'")
    # load model
    init_model_name = os.path.join(
        model_dir,
        "model.npz"
    )
    if os.path.exists(init_model_name):
        serializers.load_npz(init_model_name, model)
        print("loaded model {}".format(init_model_name))
    else:
        raise FileNotFoundError(
            "model file {} not found; train the model first".format(init_model_name)
        )
    # read one whitespace-tokenized sentence per line from stdin and decode it
    for text in (line.strip() for line in sys.stdin):
        ws = text.split()
        print("> {}".format(" ".join(ws)))
        for order in range(1, 5):
            decoded_words = decode(
                ws,
                model,
                model,
                dictionary,
                order=order
            )
            # drop the first and last decoded tokens
            # (presumably begin/end-of-sentence markers)
            answer_text = "".join(decoded_words[1:-1])
            print(answer_text)