first pass cleanup of cistematic/genomes; change bamPreprocessing
[erange.git] / rnafarPairs.py
1 """ usage: python rnafarpairs.py genome goodfile rdsfile outfile [options]
2            looks at all chromosomes simultaneously: is both slow and takes up large amount of RAM
3 """
4 try:
5     import psyco
6     psyco.full()
7 except:
8     pass
9
10 import sys
11 import time
12 import optparse
13 import pysam
14 from commoncode import getGeneInfoDict, getGeneAnnotDict, getConfigParser, getConfigIntOption, getConfigBoolOption, isSpliceEntry
15
16
17 def main(argv=None):
18     if not argv:
19         argv = sys.argv
20
21     print "rnafarPairs: version 3.7"
22     usage = "usage: python %prog genome goodfile rdsfile outfile [options]"
23
24     parser = makeParser(usage)
25     (options, args) = parser.parse_args(argv[1:])
26
27     if len(args) < 4:
28         print usage
29         sys.exit(1)
30
31     genome = args[0]
32     goodfilename = args[1]
33     bamfilename = args[2]
34     outfilename = args[3]
35
36     bamfile = pysam.Samfile(bamfilename, "rb")
37
38     rnaFarPairs(genome, goodfilename, bamfile, outfilename, options.doVerbose, options.doCache, options.maxDist)
39
40
41 def makeParser(usage=""):
42     parser = optparse.OptionParser(usage=usage)
43     parser.add_option("--verbose", action="store_true", dest="doVerbose",
44                       help="verbose output")
45     parser.add_option("--cache", action="store_true", dest="doCache",
46                       help="use cache")
47     parser.add_option("--maxDist", type="int", dest="maxDist",
48                       help="maximum distance")
49
50     configParser = getConfigParser()
51     section = "rnafarPairs"
52     doVerbose = getConfigBoolOption(configParser, section, "doVerbose", False)
53     doCache = getConfigBoolOption(configParser, section, "doCache", False)
54     maxDist = getConfigIntOption(configParser, section, "maxDist", 500000)
55
56     parser.set_defaults(doVerbose=doVerbose, doCache=doCache, maxDist=maxDist)
57
58     return parser
59
60
61 def rnaFarPairs(genome, goodfilename, bamfile, outfilename, doVerbose=False, doCache=False, maxDist=500000):
62     """ map all candidate regions that have paired ends overlapping with known genes
63     """
64
65     chromosomeList = [chrom for chrom in bamfile.references if chrom != "chrM"]
66     regions = {}
67     for chromosome in chromosomeList:
68         regions[chromosome] = []
69
70     goodDict = {}
71     goodfile = open(goodfilename)
72     for line in goodfile:
73         fields = line.split()
74         label = fields[0]
75         start = int(fields[2])
76         stop = int(fields[3])
77         goodDict[label] = line
78         regions[chromosome].append((start, stop, label))
79
80     goodfile.close()
81     if doVerbose:
82         print time.ctime()
83
84     distinct = 0
85     total = 0
86     outfile = open(outfilename,"w")
87     geneinfoDict = getGeneInfoDict(genome)
88     geneannotDict = getGeneAnnotDict(genome)
89     assigned = {}
90     farConnected = {}
91     for chromosome in chromosomeList:
92         print chromosome
93         regionList = regions[chromosome].sort()
94         uniqDict, pairCount = getUniqueReadIDFlags(bamfile, regionList, chromosome, maxDist)
95         if doVerbose:
96             print len(uniqDict), time.ctime()    
97
98         total += pairCount
99         for readID, readList in uniqDict.items():
100             flags = (readList[0]["flag"], readList[1]["flag"])
101             processed, distinctPairs = writeFarPairsToFile(flags, goodDict, genome, geneinfoDict, geneannotDict, outfile, assigned, farConnected)
102             total += processed
103             distinct += distinctPairs
104
105     entriesWritten = writeUnassignedEntriesToFile(farConnected, assigned, goodDict, outfile)
106     distinct += entriesWritten
107     outfile.write("#distinct: %d\ttotal: %d\n" % (distinct, total))
108     outfile.close()
109     print "distinct: %d\ttotal: %d" % (distinct, total)
110     print time.ctime()
111
112
113 def getUniqueReadIDFlags(bamfile, regions, chromosome, maxDist):
114     """ Returns dictionary of readsIDs with each entry consisting of a list of dictionaries of read start and read flag.
115         Only returns unique non-spliced read pairs matching the criteria given in processReads().
116     """
117     start = 1
118     readDict = {}
119     for regionstart, regionstop, regionname in regions:
120         for alignedread in bamfile.fetch(chromosome, start, regionstop):
121             if alignedread.opt("NH") == 1 and not isSpliceEntry(alignedread.cigar):
122                 if alignedread.pos >= regionstart:
123                     flag = regionname
124                 else:
125                     flag = alignedread.opt("ZG")
126
127                 try:
128                     readDict[alignedread.qname].append({"start": alignedread.pos, "flag": flag})
129                 except KeyError:
130                     readDict[alignedread.qname] = [{"start": alignedread.pos, "flag": flag}]
131
132         start = regionstop + 1
133
134     for alignedread in bamfile.fetch(chromosome, start):
135         if alignedread.opt("NH") == 1 and not isSpliceEntry(alignedread.cigar):
136             flag = alignedread.opt("ZG")
137
138             try:
139                 readDict[alignedread.qname].append({"start": alignedread.pos, "flag": flag})
140             except KeyError:
141                 readDict[alignedread.qname] = [{"start": alignedread.pos, "flag": flag}]
142
143     pairCount = len(readDict.keys())
144     for readID, readList in readDict.items():
145         if len(readList) != 2:
146             del readDict[readID]
147             pairCount -= 1
148         elif not processReads(readList, maxDist):
149             del readDict[readID]
150
151     return readDict, pairCount
152
153
154 def processReads(reads, maxDist=500000):
155     """ For a pair of readID's to be acceptable:
156             - flags must be different
157             - neither flag can be 'NM'
158             - the read starts have to be within maxDist
159     """
160     process = False
161     dist = abs(reads[0]["start"] - reads[1]["start"])
162     flag1 = reads[0]["flag"]
163     flag2 = reads[1]["flag"]
164
165     if flag1 != flag2 and flag1 != "NM" and flag2 != "NM" and dist < maxDist:
166         process = True
167
168     return process
169
170
171 def writeFarPairsToFile(flags, goodDict, genome, geneInfoDict, geneAnnotDict, outfile, assigned, farConnected):
172     """ Writes out the region information along with symbol and geneID for paired reads where one read
173         of the pair is in a known gene and the other is in a new region.  If both reads in the pair are
174         in new regions the region is added to farConnected.  No action is taken if both reads in the
175         pair are in known genes.
176     """
177     flag1, flag2 = flags
178     total = 0
179     distinct = 0
180     read1IsGood = flag1 in goodDict
181     read2IsGood = flag2 in goodDict
182
183     if read1IsGood and read2IsGood:
184         if flag1 < flag2:
185             geneID = flag1
186             farFlag = flag2
187         else:
188             geneID = flag2
189             farFlag = flag1
190
191         try:
192             farConnected[geneID].append(farFlag)
193         except KeyError:
194             farConnected[geneID] = [farFlag]
195     elif read1IsGood or read2IsGood:
196         total = 1
197         if read2IsGood:
198             farFlag = flag2
199             geneID = flag1
200         else:
201             farFlag = flag1
202             geneID = flag2
203
204         symbol = getGeneSymbol(genome, geneID, geneInfoDict, geneAnnotDict)
205         if farFlag not in assigned:
206             assigned[farFlag] = (symbol, geneID)
207             print "%s %s %s" % (symbol, geneID, goodDict[farFlag].strip())
208             outfile.write("%s %s %s" % (symbol, geneID, goodDict[farFlag]))
209             distinct = 1
210
211     return total, distinct
212
213
214 def getGeneSymbol(genome, geneID, geneInfoDict, geneAnnotDict):
215     try:
216         if genome == "dmelanogaster":
217             symbol = geneInfoDict["Dmel_%s" % geneID][0][0]
218         else:
219             symbol = geneInfoDict[geneID][0][0]
220     except (KeyError, IndexError):
221         try:
222             symbol = geneAnnotDict[(genome, geneID)][0]
223         except (KeyError, IndexError):
224             symbol = "LOC%s" % geneID
225
226     symbol = symbol.strip()
227     symbol = symbol.replace(" ","|")
228     symbol = symbol.replace("\t","|")
229
230     return symbol
231
232
233 def writeUnassignedEntriesToFile(farConnected, assigned, goodDict, outfile):
234     total, written = writeUnassignedPairsToFile(farConnected, assigned, goodDict, outfile)
235     writeUnassignedGoodReadsToFile(total, goodDict, assigned, outfile)
236
237     return written
238
239
240 def writeUnassignedPairsToFile(farConnected, assigned, goodDict, outfile):
241     total = 0
242     written = 0
243     for farFlag in farConnected:
244         geneID = ""
245         symbol = ""
246         idList = [farFlag] + farConnected[farFlag]
247         for ID in idList:
248             if ID in assigned:
249                 (symbol, geneID) = assigned[ID]
250
251         if geneID == "":
252             total += 1
253             symbol = "FAR%d" % total
254             geneID = -1 * total
255
256         for ID in idList:
257             if ID not in assigned:
258                 print "%s %s %s" % (symbol, geneID, goodDict[ID].strip())
259                 outfile.write("%s %s %s" % (symbol, geneID, goodDict[ID]))
260                 written += 1
261                 assigned[ID] = (symbol, geneID)
262
263     return total, written
264
265
266 def writeUnassignedGoodReadsToFile(farIndex, goodDict, assigned, outfile):
267     for farFlag in goodDict:
268         if farFlag not in assigned:
269             farIndex += 1
270             line = "FAR%d %d %s" % (farIndex, -1 * farIndex, goodDict[farFlag])
271             print line.strip()
272             outfile.write(line)
273
274
275 if __name__ == "__main__":
276     main(sys.argv)