https://github.com/azubiaga/tweet-country-classification
Tip revision: 56cf3e3b9d1c4637f5dbc49eb254916d74b00ae7 authored by Arkaitz Zubiaga on 14 June 2017, 09:10:05 UTC
updated readme file
run.classifier.py
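# run.classifier.py trains (or loads from cache) a scikit-learn classifier
# that predicts a tweet's country from sparse binary features (tweet text
# tokens, user name tokens, user time zone, user language, user-declared
# location tokens), then labels a test file of JSON tweets, one per line.
#
# Usage, as inferred from the argv parsing below (file names are examples):
#   python run.classifier.py test-data.json svm predictions.json 20 10000
# where 20 is the number of country labels kept (`top`) and 10000 the number
# of token features kept per token-based feature type (`topinc`). Training
# data is read from training-data.json in the working directory.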
from __future__ import print_function

import json
import operator
import os
import re
import sys

import numpy as np
from scipy.sparse import csc_matrix
from sklearn.svm import LinearSVC
from sklearn.naive_bayes import GaussianNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression

try:
    import joblib  # standalone package, used by scikit-learn >= 0.23
except ImportError:
    from sklearn.externals import joblib  # older scikit-learn releases
def get_top_labels_and_feats(filename, top, topinc):
    """Return the `top` most frequent country labels and the `topinc` most
    frequent values per feature type, reading them from the cache under
    feats/ if present, otherwise computing them from the training file."""
    ftypes = ['content', 'name', 'tz', 'ulang', 'uloc']
    featids = {ftype: [] for ftype in ftypes}
    topcountries = {}
    cachefiles = (['feats/topcountries' + str(top)] +
                  ['feats/' + ftype + str(topinc) for ftype in ftypes])
    if all(os.path.exists(f) for f in cachefiles):
        # Cached run: load the country ranking and the per-type feature lists.
        with open('feats/topcountries' + str(top), 'r') as fh:
            for line in fh:
                data = line.strip().split('\t')
                topcountries[data[0]] = int(data[1])
        for ftype in ftypes:
            with open('feats/' + ftype + str(topinc), 'r') as fh:
                featids[ftype] = fh.read().splitlines()
    else:
        # First run: count country and feature frequencies over the training data.
        countries = {}
        feats = {ftype: {} for ftype in ftypes}
        linecount = 0
        with open(filename, 'r') as fh:
            for line in fh:
                linecount += 1
                tweet = json.loads(line)
                cc = tweet['coordinates']['location']['address']['country_code']
                countries[cc] = 1 + countries.get(cc, 0)
                # Tweet text tokens, lowercased and stripped of punctuation;
                # skip retweet markers and URLs.
                for ctoken in re.sub(r'([^\s\w]|_)+', '', tweet['text'].lower()).split(' '):
                    if ctoken != 'rt' and 'http' not in ctoken:
                        feats['content'][ctoken] = 1 + feats['content'].get(ctoken, 0)
                # User name tokens.
                for nametoken in re.sub(r'([^\s\w]|_)+', '', tweet['user']['name'].lower()).split(' '):
                    feats['name'][nametoken] = 1 + feats['name'].get(nametoken, 0)
                # User time zone and interface language (single-valued features).
                feats['tz'][str(tweet['user']['time_zone'])] = 1 + feats['tz'].get(str(tweet['user']['time_zone']), 0)
                feats['ulang'][tweet['user']['lang']] = 1 + feats['ulang'].get(tweet['user']['lang'], 0)
                # User-declared location tokens.
                for loctoken in re.sub(r'([^\s\w]|_)+', '', tweet['user']['location'].lower()).split(' '):
                    feats['uloc'][loctoken] = 1 + feats['uloc'].get(loctoken, 0)
                print('Reading features: ' + str(linecount), end='\r')
        print('Reading features: ' + str(linecount), end='\n')
        # Keep the `topinc` most frequent tokens for the token-based feature
        # types; time zone and language keep every observed value.
        for ftype in ['content', 'name', 'uloc']:
            ranked = sorted(feats[ftype].items(), key=operator.itemgetter(1), reverse=True)
            featids[ftype] = [feat for feat, count in ranked[:topinc]]
        featids['tz'] = list(feats['tz'].keys())
        featids['ulang'] = list(feats['ulang'].keys())
        # Cache the feature lists, one value per line.
        for ftype in ftypes:
            with open('feats/' + ftype + str(topinc), 'w') as fw:
                for f in featids[ftype]:
                    fw.write(f + '\n')
        # Cache the `top` most frequent countries, using rank as the label id.
        scountries = sorted(countries.items(), key=operator.itemgetter(1), reverse=True)
        with open('feats/topcountries' + str(top), 'w') as fw:
            for ccount, (scountry, value) in enumerate(scountries[:top], start=1):
                topcountries[scountry] = ccount
                fw.write(scountry + '\t' + str(ccount) + '\n')
    return topcountries, featids
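# The cache written above is plain text: 'feats/<type><topinc>' holds one
# feature value per line, and 'feats/topcountries<top>' holds lines of the
# form "<country_code>\t<rank>", e.g. "us\t1" (example values hypothetical).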
def build_matrix(filename, featureids, featurelengths, toplabels, istest=0):
    """Build a sparse binary feature matrix for the tweets in `filename`,
    together with the list of ground-truth label ids (all zeros when istest=1)."""
    colcount = (featurelengths['content'] + featurelengths['name'] +
                featurelengths['tz'] + featurelengths['ulang'] + featurelengths['uloc'])
    linecount = 0
    rows = []
    cols = []
    gt = []
    values = []
    with open(filename, 'r') as fh:
        for line in fh:
            tweet = json.loads(line)
            # At training time, skip tweets whose country is not a top label.
            if istest == 1 or tweet['coordinates']['location']['address']['country_code'] in toplabels:
                # Tweet text tokens (binary presence features).
                for ctoken in set(re.sub(r'([^\s\w]|_)+', '', tweet['text'].lower()).split(' ')):
                    if ctoken in featureids['content']:
                        rows.append(linecount)
                        cols.append(featureids['content'][ctoken])
                        values.append(1)
                # User name tokens, offset past the content block.
                for nametoken in set(re.sub(r'([^\s\w]|_)+', '', tweet['user']['name'].lower()).split(' ')):
                    if nametoken in featureids['name']:
                        rows.append(linecount)
                        cols.append(featureids['name'][nametoken] + featurelengths['content'])
                        values.append(1)
                # User time zone.
                if str(tweet['user']['time_zone']) in featureids['tz']:
                    rows.append(linecount)
                    cols.append(featureids['tz'][str(tweet['user']['time_zone'])] + featurelengths['content'] + featurelengths['name'])
                    values.append(1)
                # User interface language.
                if tweet['user']['lang'] in featureids['ulang']:
                    rows.append(linecount)
                    cols.append(featureids['ulang'][tweet['user']['lang']] + featurelengths['content'] + featurelengths['name'] + featurelengths['tz'])
                    values.append(1)
                # User-declared location tokens.
                for loctoken in set(re.sub(r'([^\s\w]|_)+', '', tweet['user']['location'].lower()).split(' ')):
                    if loctoken in featureids['uloc']:
                        rows.append(linecount)
                        cols.append(featureids['uloc'][loctoken] + featurelengths['content'] + featurelengths['name'] + featurelengths['tz'] + featurelengths['ulang'])
                        values.append(1)
                if istest == 0:
                    gt.append(toplabels[tweet['coordinates']['location']['address']['country_code']])
                else:
                    gt.append(0)
                linecount += 1
                print('Loading data: ' + str(linecount), end='\r')
                sys.stdout.flush()
    print('Loading data: ' + str(linecount), end='\n')
    # TODO: tlang (tweet language) is not used as a feature yet.
    row = np.asarray(rows)
    col = np.asarray(cols)
    data = np.asarray(values)
    return csc_matrix((data, (row, col)), shape=(linecount, colcount)), gt
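# Columns of the matrix returned above are laid out as consecutive blocks:
# [content | name | tz | ulang | uloc]. As a worked example with hypothetical
# block sizes, with 10000 content, 10000 name, 400 tz and 50 ulang features,
# a uloc token with id 7 lands in column 7 + 10000 + 10000 + 400 + 50 = 20457.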
if len(sys.argv) < 6:
    print("Usage: python run.classifier.py <testdata> <classifier> <output> <top> <topinc>")
    sys.exit()
testdata = sys.argv[1]
classifier = sys.argv[2]
output = sys.argv[3]
top = int(sys.argv[4])      # number of country labels to keep
topinc = int(sys.argv[5])   # number of features to keep per token-based type
if not os.path.exists(testdata):
    print("Dataset not found.")
    sys.exit()
trainfile = 'training-data.json'
toplabels, featids = get_top_labels_and_feats(trainfile, top, topinc)
# Map every selected feature value to a column index within its type block.
featureids = {'content': {}, 'name': {}, 'tz': {}, 'ulang': {}, 'uloc': {}}
featurelengths = {'content': 0, 'name': 0, 'tz': 0, 'ulang': 0, 'uloc': 0}
for ftype in ['content', 'name', 'tz', 'ulang', 'uloc']:
    for feat in featids[ftype]:
        featureids[ftype][feat] = len(featureids[ftype])
    featurelengths[ftype] = len(featureids[ftype])
modelfile = 'models/' + classifier + '-' + str(top) + '-' + str(topinc) + '.pkl'
if not os.path.exists(modelfile):
    train_data, train_gt = build_matrix(trainfile, featureids, featurelengths, toplabels)
    print("#1 Train data loaded.")
    if classifier == "svm":
        cmodel = LinearSVC()
    elif classifier == "nb":
        # Note: GaussianNB does not accept sparse input; pass
        # train_data.toarray() here if the data fits in memory.
        cmodel = GaussianNB()
    elif classifier == "knn":
        cmodel = KNeighborsClassifier()
    elif classifier == "decisiontree":
        cmodel = DecisionTreeClassifier()
    elif classifier == "randomforest":
        cmodel = RandomForestClassifier()
    elif classifier == "maxent":
        cmodel = LogisticRegression()
    else:
        print("Unknown classifier: " + classifier)
        sys.exit()
    cmodel.fit(train_data, train_gt)
    joblib.dump(cmodel, modelfile, compress=9)
    print("#2 Model created.")
else:
    cmodel = joblib.load(modelfile)
test_data, test_gt = build_matrix(testdata, featureids, featurelengths, toplabels, 1)
print("#3 Test data loaded.")
predicted = cmodel.predict(test_data)
print("#4 Classification done.")
# Write predictions back into the tweets and, where a tweet carries a place
# with a country code, measure accuracy against it.
code_for_label = {label: cc for cc, label in toplabels.items()}
key = 0
allgt = 0
matchinggt = 0
with open(testdata, 'r') as fh, open(output, 'w') as fw:
    for line in fh:
        tweet = json.loads(line)
        tweet['country'] = code_for_label[predicted[key]]
        fw.write(json.dumps(tweet) + '\n')
        key += 1
        if 'place' in tweet and tweet['place'] is not None and 'country_code' in tweet['place']:
            allgt += 1
            if tweet['country'].lower() == tweet['place']['country_code'].lower():
                matchinggt += 1
if allgt > 0:
    print('Accuracy: ' + str(matchinggt) + '/' + str(allgt) + ' -- ' + str(float(matchinggt) / allgt))
else:
    print('Accuracy: no ground-truth country codes in the test data.')
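# A minimal sketch of the per-line tweet JSON this script expects, based on
# the fields read above (values hypothetical; real tweets carry more fields):
#   {"text": "hello world",
#    "coordinates": {"location": {"address": {"country_code": "gb"}}},
#    "user": {"name": "Some User", "time_zone": "London", "lang": "en",
#             "location": "London, UK"},
#    "place": {"country_code": "gb"}}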