https://github.com/fenderglass/Ragout
Revision 06697b66767fb24a57f230a25305d0eef12b75af authored by fenderglass on 06 August 2013, 13:00:27 UTC, committed by fenderglass on 06 August 2013, 13:00:27 UTC
1 parent 0ad2623
Tip revision: 06697b66767fb24a57f230a25305d0eef12b75af authored by fenderglass on 06 August 2013, 13:00:27 UTC
estimation of distances between contigs
estimation of distances between contigs
Tip revision: 06697b6
refass.py
#!/usr/bin/env python
import sys
import math
import graph_tools
from collections import namedtuple, defaultdict
from itertools import combinations
from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
Edge = namedtuple("Edge", ["vertex", "color", "distance"])
SyntenyBlock = namedtuple("SyntenyBlock", ["seq", "chr_id", "strand", "id", "start", "end", "chr_num"])
Permutation = namedtuple("Permutation", ["chr_id", "chr_num", "blocks"])
Connection = namedtuple("Connection", ["start", "end", "distance"])
class Node:
def __init__(self):
self.edges = []
class Contig:
def __init__(self, name):
self.name = name
self.sign = 1
self.blocks = []
class DumbContig(Contig):
def __init__(self, length):
Contig.__init__(self, "")
self.length = length
class Scaffold:
def __init__(self, left, right, contigs):
self.left = left
self.right = right
self.contigs = contigs
def parse_permutations_file(filename):
fin = open(filename, "r")
contigs = []
permutations = []
contig_name = None
ref_name = None
ref_num = 0
for line in fin:
if line.startswith(">"):
if line.startswith(">contig"):
contig_name = line.strip()[1:]
else:
ref_name = line.strip()[1:]
continue
blocks = line.strip().split(" ")[0:-1]
#contig
if contig_name:
contig = Contig(contig_name)
contig.blocks = map(int, blocks)
contigs.append(contig)
#reference
else:
permutations.append(Permutation(chr_id=ref_name, chr_num=ref_num,
blocks=map(int, blocks)))
return (permutations, contigs)
def parse_coords_file(blocks_file):
group = [[]]
num_seq_id = dict()
seq_id_num = dict()
line = [l.strip() for l in open(blocks_file) if l.strip()]
for l in line:
if l[0] == '-':
group.append([])
else:
group[-1].append(l)
for l in group[0][1:]:
l = l.split()
num_seq_id[l[0]] = l[2]
seq_id_num[l[2]] = int(l[0])
ret = dict()
for g in [g for g in group[1:] if g]:
block_id = int(g[0].split()[1][1:])
ret[block_id] = []
for l in g[2:]:
l = l.split()
chr_id = num_seq_id[l[0]]
start = int(l[2])
end = int(l[3])
chr_num = int(l[0]) - 1 #!!
ret[block_id].append(SyntenyBlock(seq='', chr_id=chr_id, strand=l[1],
id=block_id, start=start, end=end, chr_num=chr_num))
return (ret, seq_id_num)
def get_blocks_distance(left_block, right_block, ref_num, blocks_coord):
left_instances = filter(lambda b: b.chr_num == ref_num, blocks_coord[left_block])
right_instances = filter(lambda b: b.chr_num == ref_num, blocks_coord[right_block])
#print len(left_instances), len(right_instances), right_instances
assert len(left_instances) == len(right_instances) == 1
if left_instances[0].strand == "+":
left = left_instances[0].end
else:
left = left_instances[0].start
if right_instances[0].strand == "+":
right = right_instances[0].start
else:
right = right_instances[0].end
assert right >= left
return right - left - 1
def build_graph(permutations, blocks_coords):
#find duplications
duplications = set()
for perm in permutations:
current = set()
for block in perm.blocks:
if abs(block) in current:
duplications.add(abs(block))
current.add(abs(block))
print duplications
graph = defaultdict(Node)
color = 0
for perm in permutations:
prev = 0
while abs(perm.blocks[prev]) in duplications:
prev += 1
cur = prev + 1
while cur < len(perm.blocks):
while abs(perm.blocks[cur]) in duplications:
cur += 1
left_block = perm.blocks[prev]
right_block = perm.blocks[cur]
dist = get_blocks_distance(abs(left_block), abs(right_block), color, blocks_coords)
graph[-left_block].edges.append(Edge(right_block, color, dist))
graph[right_block].edges.append(Edge(-left_block, color, dist))
prev = cur
cur += 1
color += 1
return graph
def build_contig_index(contigs):
index = defaultdict(list)
for i, c in enumerate(contigs):
for block in c.blocks:
index[abs(block)].append(i)
return index
def sign(val):
return math.copysign(1, val)
def mean(vals_list):
assert len(vals_list) > 0
return sum(vals_list) / len(vals_list)
def extend_scaffolds(contigs, connections):
contig_index = build_contig_index(contigs)
scaffolds = []
visited = set()
def extend_scaffold(contig):
visited.add(contig)
scf = Scaffold(contig.blocks[0], contig.blocks[-1], [contig])
scaffolds.append(scf)
#go right
while scf.right in connections:
adjacent = connections[scf.right].end
distance = connections[scf.right].distance
assert len(contig_index[abs(adjacent)]) == 1
# print "alarm!", len(contig_index[adjacent])
# break
contig = contigs[contig_index[abs(adjacent)][0]]
if contig in visited:
break
if contig.blocks[0] == adjacent:
scf.contigs.append(DumbContig(distance))
scf.contigs.append(contig)
scf.right = contig.blocks[-1]
visited.add(contig)
continue
if -contig.blocks[-1] == adjacent:
scf.contigs.append(DumbContig(distance))
scf.contigs.append(contig)
scf.contigs[-1].sign = -1
scf.right = -contig.blocks[0]
visited.add(contig)
continue
break
#go left
while -scf.left in connections:
adjacent = -connections[-scf.left].end
distance = connections[-scf.left].distance
assert len(contig_index[abs(adjacent)]) == 1
# print "alarm!", len(contig_index[adjacent])
# break
contig = contigs[contig_index[abs(adjacent)][0]]
if contig in visited:
break
if contig.blocks[-1] == adjacent:
scf.contigs.insert(0, DumbContig(distance))
scf.contigs.insert(0, contig)
scf.left = contig.blocks[0]
visited.add(contig)
continue
if -contig.blocks[0] == adjacent:
scf.contigs.insert(0, DumbContig(distance))
scf.contigs.insert(0, contig)
scf.contigs[0].sign = -1
scf.left = -contig.blocks[-1]
visited.add(contig)
continue
break
for contig in contigs:
if contig not in visited:
extend_scaffold(contig)
return scaffolds
def get_component_of(connected_comps, vertex):
for con in connected_comps:
if vertex in con:
return con
return None
def case_on_vs_one(graph, component, connected_comps, contig_index, num_ref):
"""
a -- a
"""
MIN_REF_THRESHOLD = 2 #TODO: think about it
if len(component) != 2:
return None
num_edges = len(graph[component[0]].edges)
if num_edges not in range(MIN_REF_THRESHOLD, num_ref + 1):
return None
#print num_edges
for fst, snd in [(0, 1), (1, 0)]:
if abs(component[fst]) in contig_index and abs(component[snd]) not in contig_index:
pair_comp = get_component_of(connected_comps, -component[snd])
pair_id = pair_comp.index(-component[snd])
other_id = abs(1 - pair_id)
if pair_comp[other_id] in contigs:
print "indel found!"
return Connection(component[fst], pair_comp[other_id], None)
if abs(component[0]) in contig_index and abs(component[1]) in contig_index:
start = component[0]
end = component[1]
distance = mean(map(lambda e:e.distance, graph[start].edges))
return Connection(start, end, distance)
return None
def case_indel(graph, component, connected_comps, contig_index, num_ref):
"""
a -b
| \ |
b c
"""
if len(component) != 4:
return None
found = False
for v1, v2 in combinations(component, 2):
if v1 == -v2:
found = True
similar = [v1, v2]
different = filter(lambda v: v != v1 and v != v2, component)
if not found:
return None
#TODO: check graph structure
if abs(similar[0]) in contig_index:
print "deletion in some references"
connections = []
for s in similar:
distance = mean(map(lambda e : e.distance, graph[s].edges))
connections.append(Connection(s, graph[s].edges[0].vertex, distance))
return connections
else:
print "deletion in assembly and (possibly) references"
edges = filter(lambda e : e.vertex == different[1], graph[different[0]].edges)
distance = mean(map(lambda e : e.distance, edges))
#print distance
return [Connection(different[0], different[1], distance)]
def simple_connections(graph, connected_comps, contigs, num_ref):
connections = {}
contig_index = build_contig_index(contigs)
for component in connected_comps:
conn = case_on_vs_one(graph, component, conected_comps, contig_index, num_ref)
if conn is not None:
connections[-conn.start] = Connection(-conn.start, conn.end, conn.distance)
connections[-conn.end] = Connection(-conn.end, conn.start, conn.distance)
conn = case_indel(graph, component, conected_comps, contig_index, num_ref)
if conn is not None:
for c in conn:
connections[-c.start] = Connection(-c.start, c.end, c.distance)
connections[-c.end] = Connection(-c.end, c.start, c.distance)
print "connections infered:", len(connections)
return connections
def get_scaffolds(contigs, connections):
scaffolds = extend_scaffolds(contigs, connections)
scaffolds = filter(lambda s: len(s.contigs) > 1, scaffolds)
for scf in scaffolds:
contigs = filter(lambda c : not isinstance(c, DumbContig), scf.contigs)
for contig in contigs:
if contig.sign > 0:
print contig.blocks,
else:
print map(lambda b: -b, contig.blocks)[::-1],
print ""
return scaffolds
def output_scaffolds(input_contigs, scaffolds, out_file, write_contigs=False):
contigs = SeqIO.parse(input_contigs, "fasta")
out_stream = open(out_file, "w")
queue = {}
for rec in contigs:
queue[rec.id] = rec.seq
counter = 0
for scf in scaffolds:
scf_seq = Seq("")
buffer = ""
for i, contig in enumerate(scf.contigs):
if isinstance(contig, DumbContig):
buffer = "N" * contig.length
continue
cont_seq = queue[contig.name]
del queue[contig.name]
if contig.sign < 0:
cont_seq = cont_seq.reverse_complement()
if i > 0:
#check for overlapping
overlap = False
for window in xrange(5, 100):
if str(scf_seq)[-window:] == str(cont_seq)[0:window]:
assert overlap == False
cont_seq = cont_seq[window:]
overlap = True
if not overlap:
scf_seq += buffer
buffer = ""
scf_seq += cont_seq
name = "scaffold{0}".format(counter)
counter += 1
SeqIO.write(SeqRecord(scf_seq, id=name, description=""), out_stream, "fasta")
if write_contigs:
for h, seq in queue.iteritems():
SeqIO.write(SeqRecord(seq, id=h, description=""), out_stream, "fasta")
if __name__ == "__main__":
if len(sys.argv) < 3:
print "bg.py permutations contigs"
sys.exit(1)
blocks_coords, seqid = parse_coords_file("data/blocks_coords.txt")
#for r, rr in seqid.iteritems():
# print r, rr
permutations, contigs = parse_permutations_file(sys.argv[1])
num_references = len(permutations)
graph = build_graph(permutations, blocks_coords)
conected_comps = graph_tools.get_connected_components(graph)
connections = simple_connections(graph, conected_comps, contigs, num_references)
scaffolds = get_scaffolds(contigs, connections)
output_scaffolds(sys.argv[2], scaffolds, "scaffolds.fasta")
graph_tools.output_graph(graph, open("bg.dot", "w"))
Computing file changes ...