first pass cleanup of cistematic/genomes; change bamPreprocessing
[erange.git] / rnapath / RNAPATH.py
1 import sys
2 import optparse
3 import string
4 from numpy import zeros, int16
5 from erange.commoncode import getConfigParser, getConfigOption, getConfigIntOption
6
# Version banner; note that it is printed as soon as this module is loaded,
# including when it is imported rather than run as a script.
versionString = "RNAPATH: version 0.96"
print(versionString)
9
10
# Base -> complementary base, covering uppercase IUPAC ambiguity codes and
# lowercase a/c/g/t/n (plus "z", which maps to itself). Built once at import
# time rather than on every compNT() call, since complement() calls compNT()
# once per base.
_COMPLEMENT_TABLE = {
    "A": "T",
    "T": "A",
    "G": "C",
    "C": "G",
    "S": "S",
    "W": "W",
    "R": "Y",
    "Y": "R",
    "M": "K",
    "K": "M",
    "H": "D",
    "D": "H",
    "B": "V",
    "V": "B",
    "N": "N",
    "a": "t",
    "t": "a",
    "g": "c",
    "c": "g",
    "n": "n",
    "z": "z",
}


def compNT(nt):
    """ returns the complementary basepair to base nt

    Any character not present in _COMPLEMENT_TABLE (for example a lowercase
    ambiguity code such as "r") is mapped to "N".
    """
    return _COMPLEMENT_TABLE.get(nt, "N")
38
39
def complement(sequence, length=-1):
    """ returns the reverse complement of the sequence.

    Despite the name, the result is reversed as well as complemented: each
    base is mapped through compNT() and the order is reversed.

    If length is negative (the default) or equal to len(sequence), the whole
    sequence is reverse-complemented. Otherwise only the last `length` bases
    are used; when `length` exceeds len(sequence), the result is padded on the
    right with "N" so that it is exactly `length` characters long.
    """
    seqLength = len(sequence)

    if length < 0 or length == seqLength:
        return "".join([compNT(nt) for nt in reversed(sequence)])

    # Positions before the start of the sequence carry no data, so emit "N"
    # for them rather than letting a negative index wrap around to the end of
    # the sequence.
    tail = sequence[max(seqLength - length, 0):]
    padding = "N" * max(length - seqLength, 0)
    return "".join([compNT(nt) for nt in reversed(tail)]) + padding
63
64
def main(argv=None):
    """ Command-line entry point: parse the arguments and run rnaPath().

    Four positional arguments are required (incontigfile, distalPairs,
    outpathfile, outcontigfile). With fewer, the usage string is printed and
    the process exits with status 0. An empty or missing argv falls back to
    sys.argv.
    """
    argv = argv or sys.argv

    usage = "python %prog incontigfile distalPairs outpathfile outcontigfile [--prefix string] [--overlap bp]"
    options, args = getParser(usage).parse_args(argv[1:])

    if len(args) < 4:
        print(usage)
        sys.exit(0)

    incontigfilename, distalPairsfile, outpathfilename, outcontigfilename = args[:4]
    rnaPath(incontigfilename, distalPairsfile, outpathfilename,
            outcontigfilename, options.pathPrefix, options.overlap)
85
86
def getParser(usage):
    """ Build the optparse parser for RNAPATH.

    Adds --prefix (dest pathPrefix) and --overlap (int, dest overlap). Their
    defaults come from the "RNAPATH" section of the erange configuration, via
    getConfigOption / getConfigIntOption. "RNAPATH" and 30 are passed as the
    fallback values (presumably used when the option is not configured; see
    erange.commoncode).
    """
    configSection = "RNAPATH"
    config = getConfigParser()
    defaultPrefix = getConfigOption(config, configSection, "pathPrefix", "RNAPATH")
    defaultOverlap = getConfigIntOption(config, configSection, "overlap", 30)

    parser = optparse.OptionParser(usage=usage)
    parser.add_option("--prefix", dest="pathPrefix")
    parser.add_option("--overlap", type="int", dest="overlap")
    parser.set_defaults(pathPrefix=defaultPrefix, overlap=defaultOverlap)

    return parser
100
101
def rnaPath(incontigfilename, distalPairsfile, outpathfilename,
            outcontigfilename, pathPrefix="RNAPATH", overlap=30):
    """ Chain contigs into paths using distal read pairs and write the results.

    Contigs are read from incontigfilename and pairs from distalPairsfile.
    Paths are built by getPath(), and two files are written:
      * outpathfilename: a "#settings" header line, then for each path its
        vertex list, its contig names, any stitching messages and the
        per-edge orientation counts.
      * outcontigfilename: FASTA with one assembled sequence per path (named
        chr<pathPrefix><pathID>), followed by every contig that lies on no
        path, unchanged.
    Adjacent contigs are oriented by the most frequent paired-read sense
    combination; ties favour FR, then FF, then RR. They are merged through an
    exact overlap found within the first `overlap` bases of the incoming
    piece, or joined with an "NN" spacer when no overlap is found.
    N50 statistics are printed for the input contigs and for the output.
    Both output files are closed before returning.
    """
    with open(outpathfilename, "w") as outpathfile:
        outheader = "#settings: %s" % " ".join(sys.argv)
        print(outheader)
        outpathfile.write(outheader + "\n")

        contigNum, nameList, contigDict, origSize = getContigsFromFile(incontigfilename)
        halfSize = calculateN50(origSize)
        print("building the adjacency graph")
        pathList, edgeSenseDict, visitedDict = getPath(contigNum, distalPairsfile, nameList)

        print("found %d paths" % len(pathList))

        newSizeList = []
        with open(outcontigfilename, "w") as outcontigfile:
            for pathID, path in enumerate(pathList, 1):
                outpathfile.write("chr%s%d: %s\n" % (pathPrefix, pathID, str(path)))
                pathDescription = ",".join([nameList[vertex] for vertex in path])
                outpathfile.write(pathDescription + "\n")

                sequence, assemblyList = _assemblePath(path, contigDict, edgeSenseDict,
                                                       overlap, outpathfile)

                outcontigfile.write(">chr%s%d %dbp %s | %s\n%s\n" % (pathPrefix, pathID, len(sequence), pathDescription, str(assemblyList), sequence))
                newSizeList.append(len(sequence))

            # Contigs that were not placed on any path are passed through as-is.
            for vertex in contigDict:
                if vertex in visitedDict:
                    continue

                newSizeList.append(len(contigDict[vertex]))
                outcontigfile.write(">%s\n%s\n" % (nameList[vertex], contigDict[vertex]))

    calculateN50(newSizeList, referenceMean=halfSize)


def _assemblePath(path, contigDict, edgeSenseDict, overlap, outpathfile):
    """ Stitch the contigs along one path into a single sequence.

    Walks the path in order. At each step the next contig is oriented against
    the assembly built so far, using the orientation counts from
    _orientationCounts(), and is then joined with _stitchSequences(). Stitch
    messages and per-edge counts are printed and written to outpathfile.

    Returns (sequence, assemblyList). assemblyList is a nested tuple that
    records the orientation chosen at each join.
    """
    currentVertex = path[0]
    currentSense = "+"
    assemblyList = currentVertex
    sequence = contigDict[currentVertex]
    for nextVertex in path[1:]:
        FF, RR, RF, FR = _orientationCounts(edgeSenseDict, currentVertex,
                                            nextVertex, currentSense)
        nextSequence = contigDict[nextVertex]

        if FR >= FF and FR >= RR and FR >= RF:
            # FR: both pieces keep their orientation.
            sense1, sense2 = "+", "-"
            assemblyList = ((assemblyList, "+"), (nextVertex, "+"))
            sequence = _stitchSequences(sequence, nextSequence, overlap,
                                        outpathfile, currentVertex, nextVertex)
            currentSense = "+"
        elif FF >= RR and FF >= RF:
            # FF: reverse-complement the incoming contig.
            sense1, sense2 = "+", "+"
            assemblyList = ((assemblyList, "+"), (nextVertex, "-"))
            sequence = _stitchSequences(sequence, complement(nextSequence), overlap,
                                        outpathfile, nextVertex, currentVertex)
            currentSense = "-"
        elif RR >= RF:
            # RR: reverse-complement the assembly so far. The junction is at
            # the END of the flipped assembly, so that end is what gets matched.
            sense1, sense2 = "-", "-"
            assemblyList = ((assemblyList, "-"), (nextVertex, "+"))
            sequence = _stitchSequences(complement(sequence), nextSequence, overlap,
                                        outpathfile, nextVertex, currentVertex)
            currentSense = "+"
        else:
            # RF: reverse-complement both pieces.
            sense1, sense2 = "-", "+"
            assemblyList = ((assemblyList, "-"), (nextVertex, "-"))
            sequence = _stitchSequences(complement(sequence), complement(nextSequence),
                                        overlap, outpathfile, nextVertex, currentVertex)
            currentSense = "-"

        outstring = "(%d, %d): FF %d RR %d RF %d FR %d : %s %s\t%s" % (currentVertex, nextVertex, FF, RR, RF, FR, sense1, sense2, str(assemblyList))
        print(outstring)
        outpathfile.write(outstring + "\n")
        currentVertex = nextVertex

    return sequence, assemblyList


def _orientationCounts(edgeSenseDict, currentVertex, nextVertex, currentSense):
    """ Return (FF, RR, RF, FR) pair-orientation counts from currentVertex to nextVertex.

    Sense pairs are stored under a single key per contig pair. When the pair
    is stored as (nextVertex, currentVertex), FR and RF are swapped. If the
    assembly has already been reverse-complemented (currentSense == "-"), the
    counts are remapped so that they describe the flipped assembly.
    """
    if (currentVertex, nextVertex) in edgeSenseDict:
        senseList = edgeSenseDict[currentVertex, nextVertex]
        FR = senseList.count(("+", "-"))
        RF = senseList.count(("-", "+"))
    else:
        senseList = edgeSenseDict[nextVertex, currentVertex]
        # stored in the opposite direction: flip
        FR = senseList.count(("-", "+"))
        RF = senseList.count(("+", "-"))

    FF = senseList.count(("+", "+"))
    RR = senseList.count(("-", "-"))
    if currentSense == "-":
        # we had flipped the upstream piece! Must flip again
        FR, FF, RR, RF = RR, RF, FR, FF

    return FF, RR, RF, FR


def _stitchSequences(leftSeq, rightSeq, overlap, outpathfile, firstLabel, secondLabel,
                     seedLength=20):
    """ Join rightSeq onto the end of leftSeq, merging an exact overlap if one exists.

    The last `seedLength` bases of leftSeq are searched for within the first
    `overlap` bases of rightSeq.
      * If found, rightSeq is trimmed through the end of the match and
        appended, and a "stitching" message naming firstLabel and secondLabel
        is printed and written to outpathfile.
      * Otherwise the two pieces are joined with an "NN" spacer.
    """
    seqleft = leftSeq[-seedLength:]
    seqright = rightSeq[:overlap]
    pos = seqright.find(seqleft)
    if pos < 0:
        return leftSeq + "NN" + rightSeq

    # Use the actual seed length: leftSeq may be shorter than seedLength.
    offset = pos + len(seqleft)
    outstring = "stitching %d and %d using %d overlap" % (firstLabel, secondLabel, offset)
    print(outstring)
    outpathfile.write(outstring + "\n")
    return leftSeq + rightSeq[offset:]
245
246
def calculateN50(sizeList, referenceMean=None):
    """ Print N50 statistics for a list of contig sizes and return the threshold.

    sizeList is sorted in place into descending order. The running total of
    sizes is accumulated largest-first. The first size at which the total
    strictly exceeds referenceMean is printed as the N50, together with the
    contig count. referenceMean defaults to sum(sizeList) / 2. The 50 largest
    sizes are always printed. Returns referenceMean.
    """
    if referenceMean is None:
        referenceMean = sum(sizeList) / 2

    sizeList.sort(reverse=True)
    runningTotal = 0
    for contigSize in sizeList:
        runningTotal += contigSize
        if runningTotal > referenceMean:
            print("#contigs %s" % len(sizeList))
            print("N50 %s" % contigSize)
            break

    print(sizeList[:50])

    return referenceMean
266
267
def getContigsFromFile(contigFileName):
    """ Load contigs from a FASTA file.

    Returns (contigNum, nameList, contigDict, origSize):
      contigNum  -- number of contigs loaded
      nameList   -- nameList[i] is the first word of contig i's header, without ">"
      contigDict -- contigDict[i] is contig i's sequence, with lines joined
      origSize   -- origSize[i] == len(contigDict[i])
    A record whose header yields an empty name is skipped. If the file cannot
    be opened, an error is printed and (0, [], {}, []) is returned.
    """
    nameList = []
    origSize = []
    contigNum = 0
    contigDict = {}

    try:
        incontigfile = open(contigFileName)
    except IOError:
        print("Error opening contig file: %s" % contigFileName)
        return contigNum, nameList, contigDict, origSize

    records = []
    currentChrom = ""
    seqParts = []
    with incontigfile:
        for line in incontigfile:
            if ">" in line:
                if currentChrom != "":
                    records.append((currentChrom, "".join(seqParts)))

                currentChrom = line.strip().split()[0][1:]
                seqParts = []
            else:
                seqParts.append(line.strip())

    # The last record has no following header to trigger the store above,
    # so flush it explicitly; otherwise the final contig would be lost.
    if currentChrom != "":
        records.append((currentChrom, "".join(seqParts)))

    for name, seq in records:
        nameList.append(name)
        contigDict[contigNum] = seq
        origSize.append(len(seq))
        contigNum += 1

    return contigNum, nameList, contigDict, origSize
298
299
def getPath(contigNum, distalPairsfile, nameList):
    """ Build the contig adjacency graph and extract linear paths from it.

    Steps:
      1. Load distal pairs into a contigNum x contigNum EdgeMatrix via
         processDistalPairsFile().
      2. For every vertex that has no edge of weight > 1, zero the edge to its
         first recorded neighbour and drop the vertex from consideration.
      3. At each remaining vertex, keep only the two heaviest edges to other
         remaining vertices and zero the rest.
      4. A vertex with exactly one remaining neighbour, where that edge has
         weight > 1, becomes a leaf. traverseGraph() walks paths from the leaves.

    Returns (pathList, edgeSenseDict, visitedDict).
    """
    edgeMatrix = EdgeMatrix(contigNum)

    # diagnostic output
    print(len(edgeMatrix.edgeArray))
    try:
        print(len(edgeMatrix.edgeArray[50]))
    except IndexError:
        pass

    print("processing distal pairs")
    verticesWithEdges, vertexEdges, notSoloDict, edgeSenseDict = processDistalPairsFile(distalPairsfile, edgeMatrix, nameList)

    willVisitList = sorted(verticesWithEdges)
    print("visiting %d vertices" % len(willVisitList))

    print("cleaning up graph of edges with weight 1")
    verticesToDelete = set()
    for rindex in willVisitList:
        if rindex not in notSoloDict:
            cindex = vertexEdges[rindex][0]
            edgeMatrix.edgeArray[rindex][cindex] = 0
            edgeMatrix.edgeArray[cindex][rindex] = 0
            verticesToDelete.add(rindex)

    # Rebuild instead of list.remove() per vertex, which is O(n) each time.
    willVisitList = [vertex for vertex in willVisitList if vertex not in verticesToDelete]

    print("%d 1-edges zeroed out" % len(verticesToDelete))

    zeroedEdge = 0
    print("visiting %d vertices" % len(willVisitList))

    # Set for O(1) membership tests inside the per-edge loop below.
    willVisitSet = set(willVisitList)
    leafList = []
    print("picking top 2 edges per vertex - zero out others")
    for rindex in willVisitList:
        rEdges = [(edgeMatrix.edgeArray[rindex][avertex], avertex)
                  for avertex in vertexEdges[rindex] if avertex in willVisitSet]

        if len(rEdges) > 2:
            rEdges.sort(reverse=True)
            zeroedEdge += len(rEdges[2:])
            for (weight, cindex) in rEdges[2:]:
                edgeMatrix.edgeArray[rindex][cindex] = 0
                edgeMatrix.edgeArray[cindex][rindex] = 0
        elif len(rEdges) == 1:
            if edgeMatrix.edgeArray[rindex][rEdges[0][1]] > 1:
                leafList.append(rindex)

    print("zeroed out %d lower-weight edges at vertices with degree > 2" % zeroedEdge)
    pathList, visitedDict = traverseGraph(leafList, edgeMatrix)

    return pathList, edgeSenseDict, visitedDict
357
358
def traverseGraph(leafList, edgeMatrix):
    """ Walk the graph from each leaf and collect the resulting paths.

    leafList is sorted in place and visited in ascending order. A leaf that
    already lies on an earlier path is skipped. Each walk is done by
    edgeMatrix.visitLink(leaf), and only paths with more than one vertex are
    kept (and printed).

    Returns (pathList, visitedDict). The keys of visitedDict are every vertex
    on a kept path; the values are all "".
    """
    pathList = []
    visitedDict = {}
    leafList.sort()
    print("traveling through the graph")
    for rindex in leafList:
        if rindex in visitedDict:
            continue

        path = edgeMatrix.visitLink(rindex)
        if len(path) > 1:
            for vertex in path:
                visitedDict[vertex] = ""

            print(path)
            pathList.append(path)

    return pathList, visitedDict
377
378
def processDistalPairsFile(distalPairsfilename, edgeMatrix, nameList):
    """ Read distal read pairs and accumulate contig-to-contig edge weights.

    Lines starting with "#" are skipped. Other lines are whitespace-split:
    fields 1 and 4 name the two contigs (nameList entries carry an extra
    "chr" prefix), and fields 3 and 6 are their senses. Blank or truncated
    lines are skipped, as are pairs naming a contig that is not in nameList;
    a message is printed for each.

    Each valid pair adds 1 to edgeMatrix.edgeArray in both directions. Counts
    saturate at the maximum of the array's integer dtype instead of wrapping
    around.

    Returns (verticesWithEdges, vertexEdges, notSoloDict, edgeSenseDict):
      verticesWithEdges -- dict whose keys are the rows that have any edge
      vertexEdges       -- row -> list of distinct neighbouring rows, in first-seen order
      notSoloDict       -- dict whose keys are the rows with an edge of weight > 1
      edgeSenseDict     -- (rowA, rowB) -> list of (senseA, senseB) tuples; each
                           unordered pair is kept under the key orientation
                           first seen
    """
    from numpy import iinfo

    # name -> row; the first occurrence wins, matching list.index() semantics.
    contigToRowLookup = {}
    for row, name in enumerate(nameList):
        contigToRowLookup.setdefault(name, row)

    verticesWithEdges = {}
    vertexEdges = {}
    notSoloDict = {}
    edgeSenseDict = {}

    edgeArray = edgeMatrix.edgeArray
    # assumes an integer dtype (EdgeMatrix uses int16)
    maxWeight = iinfo(edgeArray.dtype).max

    with open(distalPairsfilename) as distalPairs:
        for line in distalPairs:
            if line.startswith("#"):
                continue

            fields = line.split()
            if len(fields) < 7:
                if fields:
                    print("skipping malformed line: %s" % line)
                continue

            contig1 = contigToRowLookup.get("chr%s" % fields[1])
            if contig1 is None:
                print("problem with end1:  %s" % line)
                continue

            sense1 = fields[3]

            contig2 = contigToRowLookup.get("chr%s" % fields[4])
            if contig2 is None:
                print("problem with end2:  %s" % line)
                continue

            sense2 = fields[6]

            # Saturate rather than let the fixed-width counter wrap negative.
            for rowA, rowB in ((contig1, contig2), (contig2, contig1)):
                if edgeArray[rowA][rowB] < maxWeight:
                    edgeArray[rowA][rowB] += 1

            verticesWithEdges[contig1] = ""
            verticesWithEdges[contig2] = ""
            if (contig1, contig2) in edgeSenseDict:
                edgeSenseDict[contig1, contig2].append((sense1, sense2))
            elif (contig2, contig1) in edgeSenseDict:
                edgeSenseDict[contig2, contig1].append((sense2, sense1))
            else:
                edgeSenseDict[contig1, contig2] = [(sense1, sense2)]

            neighbours = vertexEdges.setdefault(contig1, [])
            if contig2 not in neighbours:
                neighbours.append(contig2)

            neighbours = vertexEdges.setdefault(contig2, [])
            if contig1 not in neighbours:
                neighbours.append(contig1)

            if edgeArray[contig1][contig2] > 1:
                notSoloDict[contig1] = ""
                notSoloDict[contig2] = ""

    return verticesWithEdges, vertexEdges, notSoloDict, edgeSenseDict
448
449
class EdgeMatrix:
    """ Dense square matrix of edge weights between contigs.

    edgeArray is a numpy int16 array of shape (dimension, dimension), so a
    single weight cannot exceed 32767.
    """

    def __init__(self, dimension):
        self.dimension = dimension
        self.edgeArray = zeros((self.dimension, self.dimension), int16)


    def visitLink(self, fromVertex, ignoreList=None):
        """ Walk from fromVertex along edges of weight > 1 and return the visited path.

        At each step the lowest-index neighbour with edge weight > 1 that is
        not being ignored is chosen. On the first step the ignored vertices
        are ignoreList; afterwards only the vertex just left is ignored.
          * If that neighbour's whole row sums to exactly the connecting
            weight, the walk ends there: the edge is zeroed in both directions
            and both vertices are appended.
          * Otherwise only the forward edge is zeroed and the walk continues
            from the neighbour.
        If a vertex reached during the walk has no qualifying neighbour, the
        walk stops and that vertex is NOT included. In particular, [] is
        returned when fromVertex itself has no qualifying neighbour.

        The walk is iterative, so long paths do not hit Python's recursion
        limit. Every step zeroes at least one positive entry, so it always
        terminates.
        """
        if ignoreList is None:
            ignoreList = []

        path = []
        current = fromVertex
        ignore = ignoreList
        while True:
            row = self.edgeArray[current]
            nextVertex = None
            for candidate in (row > 1).nonzero()[0]:
                candidate = int(candidate)
                if candidate not in ignore:
                    nextVertex = candidate
                    break

            if nextVertex is None:
                return path

            # ndarray.sum() accumulates int16 in the platform integer, so a
            # heavily connected row cannot overflow the way builtin sum() could.
            if self.edgeArray[nextVertex].sum() == self.edgeArray[current][nextVertex]:
                self.edgeArray[current][nextVertex] = 0
                self.edgeArray[nextVertex][current] = 0
                return path + [current, nextVertex]

            # Zero only the forward edge; the reverse edge remains, which is
            # why the vertex we came from is ignored on the next step.
            self.edgeArray[current][nextVertex] = 0
            path.append(current)
            ignore = [current]
            current = nextVertex
478
479
if __name__ == "__main__":
    # Script entry point: hand the complete command line to main().
    main(argv=sys.argv)