rewrite of findall.py and MakeRdsFromBam to fix bugs resulting from poor initial...
[erange.git] / weighMultireads.py
1 #
2 #  weighMultireads.py
3 #  ENRAGE
4 #
5
6 #  Created by Ali Mortazavi on 10/02/08.
7 #
8
# psyco is optional: enable it when it imports and initialises cleanly,
# otherwise continue without it. Catch Exception rather than using a bare
# except so KeyboardInterrupt/SystemExit are not swallowed at import time.
try:
    import psyco
    psyco.full()
except Exception:
    pass
14
15 import sys
16 import time
17 import string
18 import optparse
19 import ReadDataset
20 from commoncode import getConfigParser, getConfigBoolOption, getConfigOption
21
22
# Announce the script version whenever the module is loaded.
print("weighMultireads: version 3.3")
24
def main(argv=None):
    """Command-line entry point: parse options and reweigh an RDS file.

    argv -- full argument vector with the program name first; falls back to
        sys.argv when empty or None.

    Prints the usage text and exits with status 1 when no rdsfile is given.
    """
    if not argv:
        argv = sys.argv

    # The usage template has a %s slot for the program name. It was never
    # filled in before, so the help text printed a literal "%s".
    progName = argv[0] if argv else "weighMultireads.py"
    usage = "usage: python %s rdsfile [--radius bp] [--noradius] [--usePairs maxDist] [--verbose] [--cache pages]" % progName

    parser = getParser(usage)
    (options, args) = parser.parse_args(argv[1:])

    if not args:
        print(usage)
        sys.exit(1)

    rdsfile = args[0]

    weighMultireads(rdsfile, options.radius, options.doRadius, options.pairDist, options.verbose, options.cachePages)
41
42
def getParser(usage):
    """Build the optparse parser for weighMultireads.

    Default values for every option come from the "weighMultireads" section
    of the project configuration (via getConfigOption/getConfigBoolOption).
    They are installed with set_defaults, so flags given on the command line
    take precedence over them.
    """
    optionSpecs = [
        ("--radius", {"type": "int", "dest": "radius"}),
        ("--noradius", {"action": "store_false", "dest": "doRadius"}),
        ("--usePairs", {"type": "int", "dest": "pairDist"}),
        ("--verbose", {"action": "store_true", "dest": "verbose"}),
        ("--cache", {"type": "int", "dest": "cachePages"}),
    ]
    parser = optparse.OptionParser(usage=usage)
    for flag, settings in optionSpecs:
        parser.add_option(flag, **settings)

    config = getConfigParser()
    sectionName = "weighMultireads"
    # Read the config values in the same order as the option table above.
    configDefaults = {}
    configDefaults["radius"] = getConfigOption(config, sectionName, "radius", None)
    configDefaults["doRadius"] = getConfigBoolOption(config, sectionName, "doRadius", True)
    configDefaults["pairDist"] = getConfigOption(config, sectionName, "pairDist", None)
    configDefaults["verbose"] = getConfigBoolOption(config, sectionName, "verbose", False)
    configDefaults["cachePages"] = getConfigOption(config, sectionName, "cachePages", None)

    parser.set_defaults(**configDefaults)
    return parser
62
63
def weighMultireads(rdsfile, radius=None, doRadius=True, pairDist=None, verbose=False, cachePages=None):
    """Reweigh the multireads stored in rdsfile.

    rdsfile -- path of the RDS database to update.
    radius -- half-width (bp) of the unique-read window used by the radius
        step; giving it forces the radius step on. Defaults to 100 when None.
    doRadius -- run the radius step (ignored/forced True when radius is set).
    pairDist -- when not None, first reweigh using paired mates within this
        distance. Reads fixed that way are handed to the radius step as
        reads to skip.
    verbose -- print progress timestamps.
    cachePages -- when not None, open the dataset with caching, raise its DB
        cache to this many pages if that exceeds the default, and save the
        cached DB back to rdsfile at the end.
    """
    doCache = cachePages is not None
    if not doCache:
        cachePages = 1

    RDS = ReadDataset.ReadDataset(rdsfile, verbose=True, cache=doCache)
    if cachePages > RDS.getDefaultCacheSize():
        RDS.setDBcache(cachePages)

    if verbose:
        print(time.ctime())

    multiIDs = RDS.getReadIDs(uniqs=False, multi=True)
    if verbose:
        print("%s %s" % ("got multiIDs ", time.ctime()))

    if pairDist is None:
        fixedReads = []
    else:
        fixedReads = reweighUsingPairs(RDS, pairDist, multiIDs, verbose)

    # An explicit radius always turns the radius step on.
    if radius is None:
        radius = 100
    else:
        doRadius = True

    if doRadius:
        reweighUsingRadius(RDS, radius, multiIDs, fixedReads, verbose)

    if doCache:
        RDS.saveCacheDB(rdsfile)

    if verbose:
        print("%s %s" % ("finished", time.ctime()))
100
101
def _closestMultiReadIndex(multiReadList, uniqChrom, ustart, ustop, pairDist, tooFar):
    """Return the index of the multiread location nearest the unique mate, or -1.

    A location qualifies when it is on uniqChrom and its start lies within
    pairDist of the mate's start (checked first) or, failing that, of its
    stop. Ties keep the earliest index.
    """
    bestIndex = -1
    bestDist = tooFar
    for index, multiRead in enumerate(multiReadList):
        if multiRead["chrom"] != uniqChrom:
            continue

        mstart = multiRead["start"]
        if abs(mstart - ustart) < pairDist:
            dist = abs(mstart - ustart)
        elif abs(mstart - ustop) < pairDist:
            dist = abs(mstart - ustop)
        else:
            continue

        if dist < bestDist:
            bestIndex = index
            bestDist = dist

    return bestIndex


def _pairReweighList(multiReadList, theMatch, theID):
    """Build the (weight, chrom, start, readID) tuples passed to RDS.reweighMultireads."""
    numMultiReads = len(multiReadList)
    matchScore = round(1 - ((numMultiReads - 1) / (100. * numMultiReads)), 3)
    otherScore = round(1 / (100. * numMultiReads), 3)
    reweighList = []
    for index, multiRead in enumerate(multiReadList):
        if index == theMatch:
            score = matchScore
        else:
            score = otherScore

        reweighList.append((score, "chr%s" % multiRead["chrom"], multiRead["start"], theID))

    return reweighList


def reweighUsingPairs(RDS, pairDist, multiIDs, verbose=False):
    """Reweigh multireads whose mate maps uniquely, favouring the nearby location.

    Candidates come from getReadIDLists: mainIDs with a unique mate and
    multireads for a single pairID. For each, the multiread location closest
    to the unique mate (same chromosome, within pairDist of the mate's start
    or stop) gets weight 1 - (n - 1) / (100 n). Every other location gets
    1 / (100 n), where n is the number of locations. Weights are rounded to
    3 places.

    RDS -- an open ReadDataset.
    pairDist -- maximum distance (bp) between the unique mate and a multiread
        location for them to be treated as a pair.
    multiIDs -- multiread IDs of the form "mainID/pairID::multiplicity".
    verbose -- print progress information.

    Returns the list of "mainID/pairID" IDs that were reweighed.
    """
    fixedPair = 0
    tooFar = pairDist * 10
    readlen = RDS.getReadSize()
    fixedReads = []
    print("doing pairs with pairDist = %d" % pairDist)
    hasSplices = RDS.dataType == "RNA"
    uniqIDs = RDS.getReadIDs(uniqs=True, multi=False, splices=hasSplices)

    if verbose:
        print("%s %s" % ("got uniqIDs ", time.ctime()))

    jointList, bothMultiList = getReadIDLists(uniqIDs, multiIDs, verbose)
    uniqDict = getUniqAndSpliceReadsFromReadIDs(RDS, jointList, verbose)
    if verbose:
        print("%s %s %s" % ("guDict actual ", len(uniqDict), time.ctime()))

    multiDict = getMultiReadsFromReadIDs(RDS, jointList, bothMultiList, verbose)
    if verbose:
        print("%s %s %s" % ("muDict actual ", len(multiDict), time.ctime()))

    RDS.setSynchronousPragma("OFF")
    try:
        for readID in jointList:
            uniqRead = uniqDict.get(readID)
            if uniqRead is None:
                # getUniqAndSpliceReadsFromReadIDs leaves out IDs it cannot
                # find (non-RNA data); these used to crash here with KeyError.
                continue

            try:
                ustart = uniqRead["start"]
                ustop = ustart + readlen
            except KeyError:
                # Entries without "start" (presumably splice reads) carry
                # startL/stopR instead.
                ustart = uniqRead["startL"]
                ustop = uniqRead["stopR"]

            multiReadList = multiDict[readID]
            theMatch = _closestMultiReadIndex(multiReadList, uniqRead["chrom"], ustart, ustop, pairDist, tooFar)
            # Index 0 is a valid match; the old "> 0" test silently skipped it.
            if theMatch < 0:
                continue

            # All multireads of a jointList entry share one pairID, so the
            # matched read's pairID identifies the multiread mate.
            # NOTE(review): theID lacks the "::multiplicity" suffix that
            # reweighUsingRadius passes to reweighMultireads - confirm which
            # form reweighMultireads matches against.
            theID = "/".join([readID, multiReadList[theMatch]["pairID"]])
            reweighList = _pairReweighList(multiReadList, theMatch, theID)
            RDS.reweighMultireads(reweighList)
            fixedPair += 1
            if verbose and fixedPair % 10000 == 1:
                print("fixed %d" % fixedPair)
                print(uniqRead)
                print(multiReadList)
                print(reweighList)

            fixedReads.append(theID)
    finally:
        # Always restore durable writes, even if a lookup above fails.
        RDS.setSynchronousPragma("ON")

    print("fixed %d pairs" % fixedPair)
    print(time.ctime())

    return fixedReads
188
189
def getReadIDLists(uniqIDs, multiIDs, verbose=False):
    """Classify paired read IDs by how their mates were mapped.

    uniqIDs -- unique read IDs of the form "mainID/pairID".
    multiIDs -- multiread IDs of the form "mainID/pairID::multiplicity";
        repeated IDs are tolerated.
    verbose -- print progress counts.

    Returns (jointList, bothMultiList):
      jointList -- mainIDs that keep a unique entry after dropping those with
        exactly two unique mates, and that have multireads for exactly one
        pairID.
      bothMultiList -- mainIDs with multireads for two distinct pairIDs.

    Raises ValueError if an ID does not split into exactly those parts.
    """
    uidDict = {}
    for readID in uniqIDs:
        (mainID, pairID) = readID.split("/")
        uidDict.setdefault(mainID, []).append(pairID)

    if verbose:
        print("%s %s %s" % ("uidDict all ", len(uidDict), time.ctime()))

    # Both mates unique: nothing to resolve. Iterate over a snapshot of the
    # keys because entries are deleted inside the loop.
    for mainID in list(uidDict):
        if len(uidDict[mainID]) == 2:
            del uidDict[mainID]

    if verbose:
        print("%s %s %s" % ("uidDict first candidates ", len(uidDict), time.ctime()))

    midDict = {}
    for readID in multiIDs:
        (frontID, _multiplicity) = readID.split("::")
        (mainID, pairID) = frontID.split("/")
        pairIDs = midDict.setdefault(mainID, [])
        if pairID not in pairIDs:
            pairIDs.append(pairID)

    if verbose:
        print("%s %s %s" % ("all multis ", len(midDict), time.ctime()))

    # Keep only unique candidates whose mate shows up among the multireads.
    for mainID in list(uidDict):
        if mainID not in midDict:
            del uidDict[mainID]

    if verbose:
        print("%s %s %s" % ("uidDict actual candidates ", len(uidDict), time.ctime()))

    jointList = []
    bothMultiList = []
    for mainID, pairIDs in midDict.items():
        numPairs = len(pairIDs)
        if numPairs == 1:
            if mainID in uidDict:
                jointList.append(mainID)
        elif numPairs == 2:
            bothMultiList.append(mainID)

    if verbose:
        print("%s %s %s" % ("joint ", len(jointList), time.ctime()))
        print("%s %s %s" % ("bothMulti ", len(bothMultiList), time.ctime()))

    return jointList, bothMultiList
247
248
def getUniqAndSpliceReadsFromReadIDs(RDS, jointList, verbose=False):
    """Map each readID in jointList to its first unique (or splice) read entry.

    Unique reads are consulted first. For RNA datasets, a readID missing from
    them falls back to the splice reads. For other datasets, such a readID is
    silently omitted from the returned dict.
    """
    allUniqs = RDS.getReadsDict(noSense=True, withChrom=True, withPairID=True, doUniqs=True, readIDDict=True)
    if verbose:
        print("%s %s %s" % ("got uniq dict ", len(allUniqs), time.ctime()))

    isRNA = RDS.dataType == "RNA"
    if isRNA:
        allSplices = RDS.getSplicesDict(noSense=True, withChrom=True, withPairID=True, readIDDict=True)
        if verbose:
            print("%s %s %s" % ("got splice dict ", len(allSplices), time.ctime()))

    selected = {}
    for readID in jointList:
        try:
            selected[readID] = allUniqs[readID][0]
        except KeyError:
            if isRNA:
                selected[readID] = allSplices[readID][0]

    return selected
268
269
def getMultiReadsFromReadIDs(RDS, jointList, bothMultiList, verbose=False):
    """Return the multiread entries for every ID in jointList, then bothMultiList.

    The full multiread dict is fetched from RDS and only the requested IDs
    are kept. Raises KeyError if a requested ID has no multireads.
    """
    allMultis = RDS.getReadsDict(noSense=True, withChrom=True, withPairID=True, doUniqs=False, doMulti=True, readIDDict=True)
    if verbose:
        print("%s %s %s" % ("got multi dict ", len(allMultis), time.ctime()))

    wantedIDs = list(jointList) + list(bothMultiList)
    return dict((readID, allMultis[readID]) for readID in wantedIDs)
283
284
def reweighUsingRadius(RDS, radius, multiIDs, readsToSkip=None, verbose=False):
    """Reweigh each multiread by the unique-read count around each of its locations.

    For every location of a multiread, unique reads are counted in the
    window [start + readlen//2 - radius, start + readlen//2 + radius] on that
    chromosome. The location then gets weight (count + 1) divided by the sum
    of (count + 1) over all of the read's locations, rounded to 2 places.

    RDS -- an open ReadDataset.
    radius -- half-width (bp) of the counting window.
    multiIDs -- multiread IDs, normally "mainID/pairID::multiplicity".
    readsToSkip -- IDs to leave untouched, e.g. the list returned by
        reweighUsingPairs. An entry matches either the full multiread ID or
        the ID with its "::multiplicity" suffix removed.
    verbose -- print the number of skipped reads.
    """
    # Previously readsToSkip was only compared with the full "::"-suffixed ID.
    # The IDs from reweighUsingPairs never carry that suffix, so nothing was
    # ever skipped. A set also makes each membership test O(1).
    if readsToSkip is None:
        skipSet = set()
    else:
        skipSet = set(readsToSkip)

    skippedReads = 0
    readlen = RDS.getReadSize()
    # Floor division: same as Python 2 "/" for an int read length
    # (assumed to be what getReadSize returns).
    halfreadlen = readlen // 2
    print("doing uniq read radius with radius = %d" % radius)
    multiDict = RDS.getReadsDict(noSense=True, withWeight=True, withChrom=True, withID=True, doUniqs=False, doMulti=True, readIDDict=True)
    print("got multiDict")
    RDS.setSynchronousPragma("OFF")
    reweighedCount = 0
    try:
        for originalMultiReadID in multiIDs:
            if originalMultiReadID in skipSet:
                skippedReads += 1
                continue

            readID = originalMultiReadID
            if "::" in readID:
                (readID, _multiplicity) = readID.split("::")

            if readID in skipSet:
                skippedReads += 1
                continue

            scores = []
            coords = []
            for read in multiDict[readID]:
                start = read["start"]
                chromosome = "chr%s" % read["chrom"]
                regionStart = start + halfreadlen - radius
                regionStop = start + halfreadlen + radius
                uniqs = RDS.getCounts(chromosome, regionStart, regionStop, uniqs=True, multi=False, splices=False, reportCombined=True)
                # +1 so a location with no nearby unique reads keeps a non-zero weight.
                scores.append(uniqs + 1)
                coords.append((chromosome, start, originalMultiReadID))

            total = float(sum(scores))
            reweighList = [(round(score / total, 2), chrom, start, theID)
                           for score, (chrom, start, theID) in zip(scores, coords)]

            RDS.reweighMultireads(reweighList)
            reweighedCount += 1
            if reweighedCount % 10000 == 0:
                print(reweighedCount)
    finally:
        # Always restore durable writes, even if a lookup above fails.
        RDS.setSynchronousPragma("ON")

    if verbose:
        print("%s %s" % ("skipped ", skippedReads))

    print("%s %s" % ("reweighted ", reweighedCount))
329
330
if __name__ == "__main__":
    # main() substitutes sys.argv when called without an argument vector.
    main()