#!/usr/bin/env python2.7
"""Random-forest contact-map prediction pipeline.

Combines three contact predictors (GaussDCA, plmDCA, a generic ML
predictor) with sequence-derived features (secondary structure, NetSurfP
accessibility, alignment statistics and a PSSM) into one feature vector
per residue pair, then runs a stack of pre-trained random-forest layers
(a "convolution" over the contact map) and writes one prediction file
per layer (<outfile>.l0 .. <outfile>.l5).

Written for Python 2.7 (see shebang); the constructs used below are
chosen to also be valid under Python 3 where possible.
"""

import joblib
import sys
import os, random, time
import numpy as np
import multiprocessing
from multiprocessing.pool import ThreadPool
from multiprocessing import Value, Array, Pool

if len(sys.argv) != 11:
    # BUG FIX: the original usage line listed no arguments at all.
    print('Usage: ' + sys.argv[0] +
          ' <gaussdca> <plmdca> <ml-contacts> <netsurfp> <ss> <stats>'
          ' <alignment> <forest-dir> <maxdepth> <outfile>')
    print('That is:')
    print(' GaussDCA prediction')
    print(' plmDCA.jl predictions')
    print(' ML contact predictions')
    print(' NetSurf RSA')
    print(' SS file')
    print(' Alignment stats')
    print(' Alignment')
    print(' Forest location')
    print(' MaxDepth')
    print(' Outfile')
    sys.exit(1)

files = sys.argv[1:]

# BUG FIX: sys.argv entries are strings, so the original "maxdepth < 0"
# compared str with int and (under Python 2) was always False, making the
# depth-unlimited branch unreachable.  Convert to int first.
maxdepth = int(files[8])
if maxdepth < 0:
    # Negative MaxDepth selects the depth-unlimited forests.
    forestlocation = files[7] + '/tlayer{:d}'
else:
    forestlocation = files[7] + '/tlayer{:d}-' + str(maxdepth)

# maximum time per layer (currently informational only; never read)
maxtime = pow(10, 6)
# fraction of trees to use (prediction time scales linearly with the number
# of trees, while expected precision is roughly the same for values > 0.3)
treedepth = 100
treefraction = 1

print(forestlocation.format(0))
if not os.path.exists(forestlocation.format(0)):
    # Fall back to forests shipped next to this script.
    forestlocation = os.path.dirname(os.path.realpath(__file__)) + '/tlayer{:d}'
print(forestlocation)

# Verify the forest data exists before doing any expensive work.
# NOTE(review): layers 0-5 are used below but only layers 0-4 are checked.
for i in range(5):
    if not os.path.exists(forestlocation.format(i) + '/tree.list'):
        sys.stderr.write('Forest data for layer {:d} is missing.\n'.format(i))
        sys.exit(0)


def parsePSIPRED(f):
    """Parse a PSIPRED-style secondary-structure file.

    Returns {residue_number: [p3, p4, p5]} where the three floats are the
    probability columns 3-5 of each 6-column data line (presumably
    coil/helix/strand -- TODO confirm against the SS file format).
    An unreadable file yields an empty dict (best effort, as before).
    """
    SSdict = {}
    try:
        lines = open(f).read().split('\n')
    except (IOError, OSError):
        return SSdict
    for line in lines:
        y = line.split()
        if len(y) != 6:  # skip headers, blanks and malformed lines
            continue
        SSdict[int(y[0])] = [float(y[3]), float(y[4]), float(y[5])]
    return SSdict


def parseNetSurfP(f):
    """Parse NetSurfP output into {residue_number: [five feature floats]}.

    Lines that do not start with a class letter ('B'uried / 'E'xposed) are
    padded with a dummy 'E' so that the column indices stay aligned.
    Columns 4, 6, 7, 8, 9 are extracted (RSA/ASA and probability columns
    -- TODO confirm against the NetSurfP version in use).
    """
    netSurfdict = {}
    for line in open(f).readlines():
        if line.find('#') == 0:  # comment / header line
            continue
        x = line.split()
        if not x:  # robustness: skip blank lines
            continue
        if line[0] not in ['B', 'E']:
            x = ['E'] + x
        netSurfdict[int(x[3])] = [float(x[col]) for col in [4, 6, 7, 8, 9]]
    return netSurfdict


def parsePSSM(alignment):
    """Build a log-odds PSSM, mean entropy and coverage stats from an alignment.

    The alignment is read in a simple FASTA-like format: lines containing
    '>' are skipped, every other non-trivial line is one aligned sequence.

    Returns a 3-tuple:
      - {position (1-based): [log-odds, ratio*log(ratio)] * 21 symbols}
      - mean per-symbol entropy term over all positions
      - [min, max, mean, median] sequence coverage (non-gap fraction)
    """
    pssm = {}
    one2number = 'ARNDCEQGHILKMFPSTWYV-'
    # Background amino-acid frequencies (order matches one2number[:-1]).
    bi = [0.0825, 0.0553, 0.0406, 0.0545, 0.0137, 0.0393, 0.0675,
          0.0707, 0.0227, 0.0595, 0.0966, 0.0584, 0.0242, 0.0386,
          0.0470, 0.0657, 0.0534, 0.0108, 0.0292, 0.0687]
    b = dict(zip(one2number[:-1], bi))

    freqs = {}       # position -> {symbol: count}
    seqcount = 0.    # float so later divisions are true divisions
    gapcount = 0
    coverage = []
    for line in open(alignment):
        if line.find('>') > -1:
            continue
        x = line.strip()
        if len(x) < 3:
            continue
        seqcount += 1
        coverage.append((len(x) - x.count('-')) / float(len(x)))
        for i in range(len(x)):
            try:
                freqs[i][x[i]] += 1
            except KeyError:
                freqs.setdefault(i, {})[x[i]] = 1
            if x[i] == '-':
                gapcount += 1

    # Gap "background" frequency is estimated from this alignment itself.
    b['-'] = gapcount / (seqcount * len(freqs))

    entropy = []
    for i in sorted(freqs.keys()):
        q = []
        for aa in one2number:
            try:
                ratio = freqs[i][aa] / (b[aa] * seqcount)
                q.append(np.log(ratio))
                q.append(ratio * np.log(ratio))
                entropy.append(ratio * np.log(ratio))
            except KeyError:
                # Symbol absent at this position: pseudo-count log-odds, no
                # entropy contribution.
                q.append(np.log(0.1 / (b[aa] * seqcount)))
                q.append(0)
                entropy.append(0)
        pssm[i + 1] = q
    return (pssm, np.mean(entropy),
            [np.min(coverage), np.max(coverage),
             np.mean(coverage), np.median(coverage)])


def parseStats(f):
    """Read six "<label> <value>" lines into a list of six floats.

    A file with the wrong number of lines yields six -1 sentinels.
    """
    ff = open(f).readlines()
    if len(ff) != 6:
        sys.stderr.write(f + ' has incorrect format!\n')
        return [-1, -1, -1, -1, -1, -1]
    return [float(line.split()[-1]) for line in ff]


def parseContacts(f):
    """Read "i j distance" lines; return the set of pairs closer than 8.

    Exits on malformed input.  (Currently unused by the pipeline below but
    kept as part of the module's interface.)
    """
    contacts = set()
    for line in open(f):
        x = line.split()
        if len(x) != 3:
            sys.stderr.write('Incorrect format for ' + f)
            sys.exit(1)
        if float(x[-1]) < 8:
            contacts.add((int(x[0]), int(x[1])))
    return contacts


def predict(dir, X):
    """Run one random-forest layer over the feature matrix X.

    Loads every tree listed in <dir>/tree.list (a random treefraction
    subset), publishes the flattened X through the module-level shared
    arrays (sharedX/sharedCount/sharedPairs -- visible to the Pool workers
    via fork semantics) and returns one probability per row of X,
    averaged over all trees.
    """
    if not os.path.exists(dir + '/tree.list'):
        sys.stderr.write(
            'Directory {:s} does not contain proper random forest!\n'.format(dir))
        sys.exit(0)
    treelist = open(dir + '/tree.list').read().strip().split('\n')
    # Shuffle so a treefraction < 1 subset is an unbiased sample.
    random.shuffle(treelist)
    treelist = treelist[:int(len(treelist) * treefraction)]
    allcount = len(treelist)

    # Fork the workers, then publish the data through shared memory.
    pool = Pool()
    sharedCount.value = len(X[0])
    sharedPairs.value = len(X)
    flatX = [item for row in X for item in row]
    sharedX[:len(flatX)] = flatX

    trees = []
    count = 0
    for t in treelist:
        # '//' keeps the progress-bar arithmetic integral on Python 3 too.
        sys.stderr.write('\rLoading: [' + '#' * (80 * count // allcount) +
                         ' ' * (80 - (80 * count) // allcount) + ']')
        trees.append(joblib.load(dir + '/' + t.split('/')[-1]))
        count += 1
    print('Chunks {:d}'.format(len(trees)))
    sys.stderr.write('\n')

    # One worker call per tree; each returns a score per residue pair.
    rrr = pool.map(predict_tree, trees)
    pool.close()

    sys.stderr.write('Summing\n')
    predictions = []
    for i in range(len(rrr[0])):
        total = 0.
        for r in rrr:
            total += r[i]
        predictions.append(total / count)
    return predictions


def predict_tree(tree):
    """Walk every shared feature row through one decision tree (worker side).

    The tree is a tuple of parallel arrays (feature, threshold, left-child,
    right-child, leaf-values) -- this resembles scikit-learn's internal
    tree representation where a child index < 0 marks a leaf; TODO confirm
    against the forest-building code.  Returns the fraction of positive
    samples at the reached leaf for each pair.
    """
    r = []
    sys.stderr.write('.')
    sys.stderr.flush()
    ncols = sharedCount.value
    for pair in range(sharedPairs.value):
        q = sharedX[pair * ncols:(pair + 1) * ncols]
        v = [0, 0]
        node = 0
        while node >= 0:
            if q[tree[0][node]] <= tree[1][node]:
                nxt = tree[2][node]
            else:
                nxt = tree[3][node]
            if nxt < 0:        # leaf reached; remember its class counts
                v = tree[4][node][0]
            node = nxt
        r.append(float(v[1]) / sum(v))
    return r


if __name__ == '__main__':
    firststart = time.time()

    # Shared-memory staging area inherited by the Pool workers via fork().
    # 'f' * 2e9 elements is an ~8 GB reservation -- TODO confirm sizing.
    sharedCount = Value('I', lock=False)   # features per residue pair
    sharedPairs = Value('I', lock=False)   # number of residue pairs
    sharedX = Array('f', 1000 * 1000 * 2000, lock=False)

    maxres = -1
    outfile = files[9]
    # NOTE(review): looks like an unfinished "skip if already done" check;
    # kept as a no-op to preserve behavior.
    if os.path.exists(outfile):
        pass

    sys.stderr.write('Doing ' + outfile + '\n')
    accessibility = parseNetSurfP(files[3])
    SSdict = parsePSIPRED(files[4])
    stats = parseStats(files[5])
    pssm, entropy, coverage = parsePSSM(files[6])

    # Read the three contact predictions.  contacts[index] maps (i, j) in
    # both orders to the predictor score; selected collects unique i < j
    # pairs seen by any predictor.
    selected = set()
    contacts = {}
    for index in range(3):
        contacts[index] = {}
        d = files[index]
        if not os.path.exists(d):
            sys.stderr.write(d + ' does not exist!\n')
            continue
        for m in open(d).readlines():
            if d.find('gdca') > -1:
                x = m.split()
                c = 2                      # score is the third column
            elif d.find('.plm') > -1:
                x = m.split(',')
                if len(x) != 3:
                    print(d + ' has wrong format!')
                    sys.exit(1)
                # BUG FIX: the original never set c in this branch and only
                # worked because the stale c == 2 happens to address the
                # last of three columns anyway.
                c = -1
            else:
                # Generic "i j 0 8 score" contact format.
                # BUG FIX: require 4+ columns before touching x[3] (the
                # original could raise IndexError on 3-column lines).
                x = m.split()
                if len(x) < 4 or x[2] != '0' or x[3] != '8':
                    continue
                c = -1
            if len(x) < 3:
                continue
            aa1 = int(x[0])
            aa2 = int(x[1])
            maxres = max(maxres, aa1, aa2)
            if x[c].find('nan') > -1:
                score = -3                 # sentinel for "no score"
            else:
                score = float(x[c])
            contacts[index][(aa1, aa2)] = score
            contacts[index][(aa2, aa1)] = score
            if aa2 > aa1:
                selected.add((aa1, aa2))

    # Rank every pair seen by predictor 0 under each predictor's score and
    # keep pairs until maxres non-local ones (separation > 4) are collected.
    clist = []
    for pair in contacts[0].keys():
        q = [pair]
        for i in contacts.keys():
            try:
                q.append(contacts[i][pair])
            except KeyError:
                q.append(-3)
        clist.append(q)
    selected2 = set()
    for i in contacts.keys():
        clist.sort(key=lambda row: -row[i + 1])
        counter = -1
        c = 0
        while counter < maxres:
            j = clist[c]
            selected2.add(j[0])
            c += 1
            if abs(j[0][0] - j[0][1]) > 4:
                counter += 1

    # Per-predictor normalisation statistics over the top-ranked pairs.
    maxscores = []
    meantop = []
    stdtop = []
    for index in range(3):
        maxscores.append(max(contacts[index].values()))
        q = []
        for s in selected2:
            try:
                q.append(contacts[index][s])
            except KeyError:
                pass
        meantop.append(np.mean(q))
        stdtop.append(np.std(q))

    selected = sorted(selected)

    # Build one feature vector per selected pair.
    X = []
    sys.stderr.write('Reading in data\n')
    sys.stderr.flush()
    count = 0
    allcount = len(selected)
    start = time.time()
    for s in selected:
        count += 1
        if count % 100 == 0:
            sys.stderr.write('\rProgress: [' + '#' * (80 * count // allcount) +
                             ' ' * (80 * (allcount - count) // allcount) + ']')
            now = time.time()
            sys.stderr.write('Time remaining: {:7.1f}s'.format(
                (allcount - count) * (now - start) / count))
        q = [abs(s[0] - s[1])]         # sequence separation
        q.extend(stats)                # global alignment statistics
        q.append(entropy)
        q.extend(coverage)
        q.extend(maxscores)            # one max score per predictor
        # 11x11 window of raw and z-scored predictor scores around the pair.
        for i in range(-5, 6):
            for j in range(-5, 6):
                for index in range(3):
                    try:
                        q.append(contacts[index][(s[0] + i, s[1] + j)])
                        q.append((contacts[index][(s[0] + i, s[1] + j)] -
                                  meantop[index]) / stdtop[index])
                    except KeyError:
                        q.append(0)
                        q.append(0)
        # 9-residue secondary-structure windows around both residues ...
        for i in range(-4, 5):
            try:
                q.extend(SSdict[s[0] + i])
            except KeyError:
                q.extend([0, 0, 0])
        for i in range(-4, 5):
            try:
                q.extend(SSdict[s[1] + i])
            except KeyError:
                q.extend([0, 0, 0])
        # ... and accessibility windows.
        for i in range(-4, 5):
            try:
                q.extend(accessibility[s[0] + i])
            except KeyError:
                q.extend([0, 0, 0, 0, 0])
        for i in range(-4, 5):
            try:
                q.extend(accessibility[s[1] + i])
            except KeyError:
                q.extend([0, 0, 0, 0, 0])
        q.extend(pssm[s[0]])
        q.extend(pssm[s[1]])
        X.append(q)
    sys.stderr.write('\n')
    sys.stderr.flush()

    # Base layer: sequence/predictor features only.
    sys.stderr.write('\nPredicting base layer:\n')
    X = tuple(X)
    p = predict(forestlocation.format(0), X)
    previouslayer = {}
    of = open(outfile + '.l0', 'w')
    for t in range(len(p)):
        of.write('{:d} {:d} {:7.5f}\n'.format(selected[t][0], selected[t][1], p[t]))
        previouslayer.setdefault(selected[t][0], {})[selected[t][1]] = p[t]
    of.close()

    # Convolution layers: base features plus an 11x11 window of the
    # previous layer's predictions around each pair.
    Xp = X
    Yp = selected
    for layer in range(1, 6):
        sys.stderr.write('\nPredicting convolution layer {:d}:\n'.format(layer))
        X = []
        for idx in range(len(Xp)):
            y = Yp[idx]
            q = list(Xp[idx])
            for i in range(-5, 6):
                for j in range(-5, 6):
                    try:
                        q.append(previouslayer[y[0] + i][y[1] + j])
                    except KeyError:
                        q.append(-3)    # outside map / unscored pair
            X.append(q)
        X = tuple(X)
        p = predict(forestlocation.format(layer), X)
        previouslayer = {}
        of = open(outfile + '.l{:d}'.format(layer), 'w')
        for t in range(len(p)):
            of.write('{:d} {:d} {:7.5f}\n'.format(Yp[t][0], Yp[t][1], p[t]))
            previouslayer.setdefault(Yp[t][0], {})[Yp[t][1]] = p[t]
        of.close()

    sys.stderr.write('\n\nSuccesfully completed in {:7.1f} seconds\n'.format(
        time.time() - firststart))