# https://github.com/mskwark/PconsC3
# Tip revision: e84ffb364118dac7179817653ef00264de5f2475 authored by Arne Elofsson on 18 May 2018, 13:50:56 UTC
# Merge branch 'master' of github.com:/mskwark/PconsC3
# Merge branch 'master' of github.com:/mskwark/PconsC3
# Tip revision: e84ffb3
# predict-simultaneous.py
#!/usr/bin/env python2.7
import joblib
import sys
import os, random, time
import numpy as np
import multiprocessing
from multiprocessing.pool import ThreadPool
from multiprocessing import Value, Array, Pool
if len(sys.argv) != 11:
print 'Usage: ' + sys.argv[0] + ' <files>'
print 'That is:'
print ' GaussDCA prediction'
print ' plmDCA.jl predictions'
print ' ML contact predictions'
print ' NetSurf RSA'
print ' SS file'
print ' Alignment stats'
print ' Alignment'
print ' Forest location'
print ' MaxDepth'
print ' Outfile'
sys.exit(1)
files = sys.argv[1:]
maxdepth = files[8]
# maxdepth = -1
if maxdepth < 0:
forestlocation = files[7] + '/tlayer{:d}'
else:
forestlocation = files[7] + '/tlayer{:d}-' + str(maxdepth)
# maximum time per layer
maxtime = pow(10,6)
# fraction of trees to use (prediction time scales linearly with the number of trees,
# while expected precision is roughly the same for values > 0.3
treedepth = 100
treefraction = 1
print forestlocation.format(0)
if not os.path.exists(forestlocation.format(0)):
forestlocation = os.path.dirname(os.path.realpath(__file__)) + '/tlayer{:d}'
print forestlocation
for i in range(5):
abort = False
if not os.path.exists(forestlocation.format(i) + '/tree.list'):
sys.stderr.write('Forest data for layer {:d} is missing.\n'.format(i))
abort = True
if abort:
sys.exit(0)
def parsePSIPRED(f):
    """Parse a PSIPRED .ss2 file into {residue_number: [p1, p2, p3]}.

    Data rows have exactly six whitespace-separated fields: residue number,
    amino acid, SS letter, and three per-state probabilities (columns 3-5
    are taken as-is; presumably coil/helix/strand order -- confirm against
    the PSIPRED VFORMAT spec).  Header and blank lines are skipped.

    Returns an empty dict if the file cannot be opened/read.
    """
    SSdict = {}
    try:
        # Narrowed from a bare except: only I/O failures mean "no data";
        # the file handle is now closed deterministically.
        with open(f) as fh:
            lines = fh.read().split('\n')
    except IOError:
        return SSdict
    for l in lines:
        y = l.split()
        # Only well-formed 6-field data rows are kept.
        if len(y) != 6:
            continue
        SSdict[int(y[0])] = [float(y[3]), float(y[4]), float(y[5])]
    return SSdict
def parseNetSurfP(f):
    """Parse NetSurfP tabular output into a per-residue feature dict.

    Returns {residue_number: [five floats]} taken from columns 4, 6, 7, 8, 9
    of each data row (column 4 is the RSA; the remaining four are score
    columns -- exact semantics per the NetSurfP output format).
    Lines starting with '#' are comments and skipped.
    """
    netSurfdict = {}
    # File handle is now closed deterministically (was leaked before).
    with open(f) as fh:
        for l in fh:
            if l.find('#') == 0:
                continue
            # Skip blank lines; the original crashed on them with an
            # IndexError when indexing the padded, empty field list.
            if not l.strip():
                continue
            x = l.split()
            # Data rows normally begin with the burial class ('B' or 'E').
            # Some rows lack it, so prepend a dummy class to realign columns.
            if l[0] not in ['B', 'E']:
                x = ['E'] + x
            # Column 3 is the 1-based sequence position.
            netSurfdict[int(x[3])] = [float(x[idx]) for idx in [4, 6, 7, 8, 9]]
    return netSurfdict
def parsePSSM(alignment):
    """Build a log-odds-style profile from an alignment file.

    Lines containing '>' are headers; every other line of length >= 3 is an
    aligned sequence.  Amino-acid background frequencies come from the fixed
    table `bi`; the gap ('-') background is estimated from the alignment.

    Returns a 3-tuple:
      * {column (1-based): 42 floats -- for each of the 21 symbols in
        'ARNDCEQGHILKMFPSTWYV-' the pair (log p, p*log p), where
        p = count / (background * n_sequences); unseen symbols get a 0.1
        pseudo-count for the log term and 0 for the p*log p term}
      * mean of all collected p*log p terms (entropy-like summary)
      * [min, max, mean, median] of per-sequence coverage (non-gap fraction)
    """
    pssm = {}
    one2number = 'ARNDCEQGHILKMFPSTWYV-'
    # Background frequencies for the 20 amino acids, same order as above.
    bi = [ 0.0825, 0.0553, 0.0406, 0.0545, 0.0137, 0.0393, 0.0675, 0.0707, 0.0227, 0.0595, 0.0966, 0.0584, 0.0242, 0.0386, 0.0470, 0.0657, 0.0534, 0.0108, 0.0292, 0.0687 ]
    b = {}
    for aa in one2number[:-1]:
        b[aa] = bi[one2number.find(aa)]
    freqs = {}      # column index -> {symbol: count}
    seqcount = 0.
    gapcount = 0
    coverage = []
    # File handle now closed deterministically.
    with open(alignment) as fh:
        for line in fh:
            if line.find('>') > -1:
                continue
            x = line.strip()
            if len(x) < 3:
                continue
            seqcount += 1
            coverage.append((len(x) - x.count('-')) / float(len(x)))
            for i in range(len(x)):
                # Explicit counting replaces the old nested bare try/excepts.
                column = freqs.setdefault(i, {})
                column[x[i]] = column.get(x[i], 0) + 1
                if x[i] == '-':
                    gapcount += 1
    # Gap background frequency estimated from the alignment itself.
    # NOTE(review): b['-'] is 0 for a gap-free alignment, which makes the
    # pseudo-count term below divide by zero -- same as the original code;
    # confirm inputs always contain at least one gap.
    b['-'] = gapcount / (seqcount * len(freqs.keys()))
    entropy = []
    for i in sorted(freqs.keys()):
        q = []
        for sym in one2number:
            if sym in freqs[i]:
                # Hoist the repeated ratio; numerically identical to before.
                p = freqs[i][sym] / (b[sym] * seqcount)
                logp = np.log(p)
                q.append(logp)
                q.append(p * logp)
                entropy.append(p * logp)
            else:
                # Unseen symbol: 0.1 pseudo-count for the log-odds term,
                # zero contribution to the entropy terms.
                q.append(np.log(0.1 / (b[sym] * seqcount)))
                q.append(0)
                entropy.append(0)
        # Residue numbering in the returned profile is 1-based.
        pssm[i+1] = q
    return (pssm, np.mean(entropy), [np.min(coverage), np.max(coverage), np.mean(coverage), np.median(coverage)])
def parseStats(f):
    """Parse the six-line alignment-statistics file.

    The last whitespace-separated token of each line is taken as a float.
    Returns six -1 placeholders (and warns on stderr) when the file does
    not contain exactly six lines.
    """
    rows = open(f).readlines()
    if len(rows) == 6:
        return [float(row.split()[-1]) for row in rows]
    sys.stderr.write(f + ' has incorrect format!\n')
    return [-1] * 6
def parseContacts(f):
    """Parse a native-contact file of 'res_i res_j distance' rows.

    Returns the set of (i, j) residue pairs whose distance is below 8.
    Exits with status 1 on any row that does not have exactly 3 fields.
    """
    contacts = set()
    # File handle now closed deterministically (was leaked before).
    with open(f) as fh:
        for line in fh:
            x = line.split()
            if len(x) != 3:
                # Added the missing trailing newline to the error message.
                sys.stderr.write('Incorrect format for ' + f + '\n')
                sys.exit(1)
            if float(x[-1]) < 8:
                contacts.add((int(x[0]), int(x[1])))
    return contacts
def predict(dir, X):
    """Score every feature row in X with the random forest stored in `dir`
    (one joblib file per tree, listed in `dir`/tree.list) and return a list
    with one averaged per-tree score per row.

    Note: `dir` shadows the builtin of the same name (kept for compatibility).
    The feature matrix is handed to the worker processes through the
    module-global sharedX/sharedCount/sharedPairs buffers (set up in
    __main__) instead of pickling X for every task.
    Exits (status 0) if tree.list is missing.
    """
    if not os.path.exists(dir + '/tree.list'):
        sys.stderr.write('Directory {:s} does not contain proper random forest!\n'.format(dir))
        # NOTE(review): exits with status 0 even though this is an error path.
        sys.exit(0)
    treelist = open(dir + '/tree.list').read().strip().split('\n')
    # Use a random subset of the trees; treefraction == 1 keeps all of them.
    random.shuffle(treelist)
    treelist = treelist[:int(len(treelist)*treefraction)]
    allcount = len(treelist)
    trees = []
    count = 0
    # Create the pool BEFORE loading the trees; workers inherit the shared
    # buffers (and later read them in predict_tree) via fork.
    pool = Pool()
    # Publish the flattened feature matrix into the shared-memory buffers.
    sharedCount.value = len(X[0])
    sharedPairs.value = len(X)
    flatX = [item for sublist in X for item in sublist]
    sharedX[:len(flatX)] = flatX
    q = []
    for t in treelist:
        # Python 2 integer division draws the 80-column progress bar.
        sys.stderr.write('\rLoading: [' + '#' * (80*count/allcount) + ' ' * (80- (80*count)/allcount) + ']')
        # trees.append((joblib.load(dir + '/' + t.split('/')[-1]), X))
        trees.append(joblib.load(dir + '/' + t.split('/')[-1]))
        count += 1
    print 'Chunks', len(trees)
    sys.stderr.write('\n')
    # predictions = np.zeros(len(X))
    predictions = [0] * len(X)
    allcount = len(trees)
    ccc = 0.
    loadtime = 0.
    start = time.time()
    divider = lambda p: p/count
    # predictF = lambda t: predict_tree(t, X)
    # rrr = map(predictF, trees)
    # One task per tree; each worker scores ALL pairs against its tree.
    rrr = pool.map(predict_tree, trees)
    pool.close()
    predictions = []
    sys.stderr.write('Summing\n')
    # Sum the per-tree score vectors element-wise over all trees...
    for i in range(len(rrr[0])):
        s = 0.
        for r in rrr:
            s+= r[i]
        predictions.append(s)
    # ...and average by the number of trees actually loaded (count).
    predictions = map(lambda p: p/count, predictions)
    return predictions
def predict_tree(tree):
    """Score every residue pair in the shared feature buffer with one tree.

    `tree` is the flattened decision tree loaded by predict(): per-node
    arrays of (feature index, threshold, left child, right child, leaf
    values), where a negative child index marks a leaf.  Features are read
    from the module-global sharedX/sharedCount/sharedPairs buffers.
    Returns one positive-class fraction per pair.
    """
    scores = []
    # One dot per tree as a lightweight progress indicator.
    sys.stderr.write('.')
    sys.stderr.flush()
    width = sharedCount.value
    for pair in range(sharedPairs.value):
        features = sharedX[pair * width:(pair + 1) * width]
        leaf = [0, 0]
        node = 0
        # Walk from the root until a negative child index (leaf) is reached.
        while node >= 0:
            if features[tree[0][node]] <= tree[1][node]:
                nxt = tree[2][node]
            else:
                nxt = tree[3][node]
            if nxt < 0:
                leaf = tree[4][node][0]
            node = nxt
        # Fraction of class-1 samples at the leaf.
        scores.append(float(leaf[1]) / sum(leaf))
    return scores
if __name__ == '__main__':
    firststart = time.time()
    # Shared-memory buffers used by predict()/predict_tree() to hand the
    # feature matrix to the Pool workers without pickling (lock=False:
    # written once before the pool maps, then only read).
    sharedCount = Value('I', lock=False)
    sharedPairs = Value('I', lock=False)
    sharedX = Array('f', 1000*1000*2000, lock=False)
    Y = []
    maxres = -1
    outfile = files[9]
    if os.path.exists(outfile):
        pass
    selected = set()
    contacts = {}
    X = []
    sys.stderr.write('Doing ' + outfile + '\n')
    # Per-residue feature sources.
    accessibility = parseNetSurfP(files[3])
    SSdict = parsePSIPRED(files[4])
    stats = parseStats(files[5])
    pssm = parsePSSM(files[6])
    entropy = pssm[1]
    coverage = pssm[2]
    pssm = pssm[0]
    selected = set()
    # Read the three contact-prediction inputs: 0 = GaussDCA, 1 = plmDCA,
    # 2 = ML contact predictions; the format is detected from the file name.
    for index in range(3):
        contacts[index] = {}
        d = files[index]
        r = []
        if not os.path.exists(d):
            sys.stderr.write(d + ' does not exist!\n')
            continue
        infile = open(d).readlines()
        for m in infile:
            if d.find('gdca') > -1:
                x = m.split()
                c = 2
            elif d.find('.plm') > -1:
                x = m.split(',')
                if len(x) != 3:
                    print d + ' has wrong format!'
                    sys.exit(1)
            else:
                x = m.split()
                # ML contact rows: keep only the 0..8 distance bin.
                if len(x) < 3 or x[2] != '0' or x[3] != '8':
                    continue
            # Score is taken from the last column (for the 3-column gdca/plm
            # formats this is the same field as index 2).
            c = -1
            if len(x) < 3:
                continue
            aa1 = int(x[0])
            aa2 = int(x[1])
            # Track the largest residue number seen (proxy for length).
            if aa1 > maxres:
                maxres = aa1
            if aa2 > maxres:
                maxres = aa2
            if x[c].find('nan') > -1:
                score = -3
            else:
                score = float(x[c])
            # Store symmetrically so window lookups work in both directions.
            contacts[index][(aa1, aa2)] = score
            contacts[index][(aa2, aa1)] = score
            if not aa2 > aa1:
                continue
            selected.add((aa1,aa2))
    # Rank all pairs per predictor; -3 is the placeholder for missing scores.
    clist = []
    for c in contacts[0].keys():
        q = [ c ]
        for i in contacts.keys():
            try:
                q.append( contacts[i][c] )
            except:
                q.append( -3 )
        clist.append(q)
    # selected2: union of each predictor's top pairs, counting only
    # non-local pairs (|i-j| > 4) toward the maxres quota.
    selected2 = set()
    for i in contacts.keys():
        clist.sort(key = lambda x: -x[i+1])
        counter = -1
        c = 0
        while counter < maxres:
            j = clist[c]
            selected2.add(j[0])
            c+=1
            if abs(j[0][0] - j[0][1]) > 4:
                counter += 1
    # Per-predictor statistics over the top pairs, used below to z-normalise
    # the windowed scores.
    maxscores = []
    meantop = []
    stdtop = []
    for index in range(3):
        maxscores.append(max(contacts[index].values()))
        q = []
        for s in list(selected2):
            try:
                q.append(contacts[index][s])
            except:
                pass
        meantop.append(np.mean(q))
        stdtop.append(np.std(q))
    selected = list(selected)
    selected.sort()
    lastseeny = -1
    X = []
    Y = []
    sys.stderr.write('Reading in data\n')
    sys.stderr.flush()
    count = 0
    allcount = len(selected)
    start = time.time()
    # Build one feature row per candidate pair (Python 2 integer division
    # drives the progress bar).
    for s in selected:
        count += 1
        if count % 100 == 0:
            sys.stderr.write('\rProgress: [' + '#' * (80*count/allcount) + ' ' * (80*(allcount-count)/allcount) + ']')
            now = time.time()
            sys.stderr.write('Time remaining: {:7.1f}s'.format( (allcount-count) * (now-start)/count ) )
        q = []
        # Global/pair-level features: separation, alignment stats, entropy,
        # coverage stats, per-predictor maxima.
        q.append(abs(s[0]-s[1]))
        for ss in stats:
            q.append(ss)
        q.append(entropy)
        q.extend(coverage)
        q.append(maxscores[0])
        q.append(maxscores[1])
        q.append(maxscores[2])
        # 11x11 window of raw and z-scored predictor values around the pair;
        # missing cells contribute (0, 0).
        for i in range(-5, 6):
            for j in range(-5, 6):
                for index in range(3):
                    try:
                        q.append(contacts[index][(s[0]+i, s[1]+j)])
                        q.append((contacts[index][(s[0]+i, s[1]+j)] - meantop[index])/stdtop[index])
                    except:
                        q.append(0)
                        q.append(0)
        # 9-residue windows of secondary structure and accessibility for both
        # residues, zero-padded past the sequence ends.
        for i in range(-4, 5):
            try:
                q.extend(SSdict[s[0]+i] )
            except:
                q.extend([0,0,0])
        for i in range(-4, 5):
            try:
                q.extend(SSdict[s[1]+i] )
            except:
                q.extend([0,0,0])
        for i in range(-4, 5):
            try:
                q.extend(accessibility[s[0]+i] )
            except:
                q.extend([0,0,0,0,0])
        for i in range(-4, 5):
            try:
                q.extend(accessibility[s[1]+i] )
            except:
                q.extend([0,0,0,0,0])
        q.extend(pssm[s[0]])
        q.extend(pssm[s[1]])
        X.append(q)
    sys.stderr.write('\n')
    sys.stderr.flush()
    # first layer
    sys.stderr.write('\nPredicting base layer:\n')
    X = tuple(X)
    p = predict(forestlocation.format(0), X)
    # Write layer-0 predictions and index them by (res_i, res_j) for the
    # next layer's receptive field.
    of = open(outfile + '.l0', 'w')
    previouslayer = {}
    for t in range(len(p)):
        of.write('{:d} {:d} {:7.5f}\n'.format(selected[t][0], selected[t][1], p[t]))
        try:
            previouslayer[selected[t][0]][selected[t][1]] = p[t]
        except:
            previouslayer[selected[t][0]] = {}
            previouslayer[selected[t][0]][selected[t][1]] = p[t]
    of.close()
    Xp = X
    Yp = selected
    # Deep layers 1-5: the base feature row plus an 11x11 window of the
    # previous layer's predictions (-3 marks missing neighbours).
    for layer in range(1,6):
        X = []
        sys.stderr.write('\nPredicting convolution layer {:d}:\n'.format(layer))
        for p in range(len(Xp)):
            y = Yp[p]
            q = list(Xp[p])
            for i in range(-5,6):
                for j in range(-5,6):
                    try:
                        q.append(previouslayer[y[0]+i][y[1] + j])
                    except:
                        q.append(-3)
            X.append(q)
        X = tuple(X)
        p = predict(forestlocation.format(layer), X)
        previouslayer = {}
        # One output file per layer: <outfile>.l<N>.
        of = open(outfile + '.l{:d}'.format(layer), 'w')
        for t in range(len(p)):
            of.write('{:d} {:d} {:7.5f}\n'.format(Yp[t][0], Yp[t][1], p[t]))
            try:
                previouslayer[Yp[t][0]][Yp[t][1]] = p[t]
            except:
                previouslayer[Yp[t][0]] = {}
                previouslayer[Yp[t][0]][Yp[t][1]] = p[t]
        of.close()
    sys.stderr.write('\n\nSuccesfully completed in {:7.1f} seconds\n'.format(time.time() - firststart))