download_data.py
# ==============================================================================
# Copyright (c) Microsoft. All rights reserved.
# Licensed under the MIT license. See LICENSE.md file in the project root
# for full license information.
# ==============================================================================

# This program downloads training, validation and test data and creates additional files 
# with token-ids and frequencies.

import os, tarfile, operator
try:
    from urllib.request import urlretrieve
except ImportError:
    from urllib import urlretrieve
from io import open

# accumulate word counts in dictionary
def add_to_count(word, word2Count):
    if word in word2Count:
        word2Count[word] += 1
    else:
        word2Count[word] = 1

# returns a dictionary mapping each word in the given text file to its frequency
def count_words_in_file(path):
    with open(path,'r') as f:
        word2count = {}
        for line in f:
            words = line.split()
            for word in words:
                add_to_count(word, word2count)
        return word2count
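
# For example, a file containing the single line "the cat saw the dog" would
# yield {'the': 2, 'cat': 1, 'saw': 1, 'dog': 1}.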

# from a dictionary mapping words to counts creates four files:
# * a vocabulary file containing all words sorted by decreasing frequency, one word per line
# * a frequency file containing the frequencies of these words, one number per line
# * a tab-separated file mapping each word to its count
# * a tab-separated file mapping each word to its id (its position in the vocabulary file)
def write_vocab_and_frequencies(word2count, vocab_file_path, freq_file_path, word2count_file_path, word2id_file_path):
    vocab_file = open(vocab_file_path,'w', newline='\r\n')
    freq_file = open(freq_file_path,'w', newline='\r\n')
    word2count_file = open(word2count_file_path,'w', newline='\r\n')
    word2id_file = open(word2id_file_path,'w', newline='\r\n')
    sorted_entries = sorted(word2count.items(), key=operator.itemgetter(1), reverse=True)

    word_id = 0
    for word, freq in sorted_entries:
        vocab_file.write(word + "\n")
        freq_file.write(u"%i\n" % freq)
        word2count_file.write(u"%s\t%i\n" % (word, freq))
        word2id_file.write(u"%s\t%i\n" % (word, word_id))
        word_id += 1

    # close the files
    vocab_file.close()
    freq_file.close()
    word2count_file.close()
    word2id_file.close()
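
# Illustration of the four output formats with hypothetical counts (not taken
# from the actual PTB data): given word2count = {'the': 3, 'cat': 1},
# vocab.txt contains the lines "the" and "cat", freq.txt the lines "3" and "1",
# token2freq.txt the lines "the\t3" and "cat\t1", and token2id.txt the lines
# "the\t0" and "cat\t1".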

# copies a text file, appending '<eos>' at the end of each line;
# optionally truncates the output to at most max_lines_in_output lines
def append_eos_and_trim(from_path, to_path, max_lines_in_output = None):
    with open(from_path,'r') as f:
        lines = f.read().splitlines()

    with open(to_path,'w') as f:
        count = 0
        for line in lines:
            count += 1
            if max_lines_in_output is not None and count > max_lines_in_output:
                break

            f.write(line + "<eos>\n")
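
# For example, an input line "a b c" is written to the output file as
# "a b c<eos>"; '<eos>' is appended directly, with no extra space.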


class Paths(object):

    # Relative paths of the data file in the downloaded tar file
    tar_path_test       = './simple-examples/data/ptb.test.txt'
    tar_path_train      = './simple-examples/data/ptb.train.txt'
    tar_path_validation = './simple-examples/data/ptb.valid.txt'

    tmp_dir  = './tmp/'

    # final paths of the data files
    data_dir = './ptb/'
    test       = os.path.join(data_dir, 'test.txt')
    train      = os.path.join(data_dir, 'train.txt')
    validation = os.path.join(data_dir, 'valid.txt')

    # files derived from the data files
    tokens          = os.path.join(data_dir, 'vocab.txt')
    frequencies     = os.path.join(data_dir, 'freq.txt')
    token2frequency = os.path.join(data_dir, 'token2freq.txt')
    token2id        = os.path.join(data_dir, 'token2id.txt')
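
    # Expected layout after a successful run (relative to this script):
    #   ./ptb/train.txt, ./ptb/valid.txt, ./ptb/test.txt  (data with <eos> markers)
    #   ./ptb/vocab.txt, ./ptb/freq.txt, ./ptb/token2freq.txt, ./ptb/token2id.txt
    #   (the last four are derived from train.txt)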


if __name__=='__main__':

    os.chdir(os.path.abspath(os.path.dirname(__file__)))

    # downloading the data
    url ="http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz"

    tmpGz = "tmp.tgz"
    if not os.path.isfile(tmpGz):
        print("downloading " + url + " to " + tmpGz)
        urlretrieve(url, tmpGz)

    # extracting the files we need from the tarfile
    fileReader = tarfile.open(tmpGz, 'r')
    print("Extracting files into: " + Paths.tmp_dir)
    fileReader.extract(Paths.tar_path_test,       path = Paths.tmp_dir)
    fileReader.extract(Paths.tar_path_train,      path = Paths.tmp_dir)
    fileReader.extract(Paths.tar_path_validation, path = Paths.tmp_dir)

    print('creating final data files in directory: ' + Paths.data_dir)
    if not os.path.isdir(Paths.data_dir):
        os.mkdir(Paths.data_dir)
    append_eos_and_trim(os.path.join(Paths.tmp_dir, Paths.tar_path_test),       Paths.test)
    append_eos_and_trim(os.path.join(Paths.tmp_dir, Paths.tar_path_train),      Paths.train)
    append_eos_and_trim(os.path.join(Paths.tmp_dir, Paths.tar_path_validation), Paths.validation, max_lines_in_output = 50)

    fileReader.close()

    # remove the temporary file
    os.remove(tmpGz)

    # from the training file generate a number of helper files
    word2count = count_words_in_file(Paths.train)
    write_vocab_and_frequencies(word2count, Paths.tokens, Paths.frequencies, Paths.token2frequency, Paths.token2id)
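
# A minimal sketch of how a downstream script could read token2id.txt back
# into a dictionary (an assumed use, not something this script requires):
#
#   with open(Paths.token2id, 'r') as f:
#       token2id = {word: int(i) for word, i in (line.split('\t') for line in f)}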