erange version 4.0a dev release
[erange.git] / weighMultireads.py
1 #
2 #  weighMultireads.py
3 #  ENRAGE
4 #
5
6 #  Created by Ali Mortazavi on 10/02/08.
7 #
8
# psyco is an optional accelerator: use it when it is installed and carry on
# without it otherwise.
try:
    import psyco
    psyco.full()
except Exception:
    # Still best effort (any failure to load or enable psyco is ignored), but
    # unlike a bare "except:" this no longer swallows KeyboardInterrupt or
    # SystemExit raised while the module is being imported.
    pass
14
15 import sys
16 import time
17 import string
18 import optparse
19 import ReadDataset
20 from commoncode import getConfigParser, getConfigBoolOption, getConfigOption
21
22
# Announce the script name and version as soon as the module is loaded.
print("weighMultireads: version 3.3")
24
def main(argv=None):
    """Command-line entry point: parse the arguments and reweigh the multireads.

    argv is the full argument vector, program name included; sys.argv is used
    when it is None or empty.  Prints the usage message and exits with status
    1 when no rds file is given.
    """
    if not argv:
        argv = sys.argv

    # Interpolate the program name here: optparse only expands "%prog", so the
    # "%s" placeholder used to be shown literally in the usage message.
    usage = "usage: python %s rdsfile [--radius bp] [--noradius] [--usePairs maxDist] [--verbose] [--cache pages]" % argv[0]

    parser = getParser(usage)
    (options, args) = parser.parse_args(argv[1:])

    if len(args) < 1:
        print(usage)
        sys.exit(1)

    rdsfile = args[0]

    weighMultireads(rdsfile, options.radius, options.doRadius, options.pairDist, options.verbose, options.cachePages)
41
42
def getParser(usage):
    """Return the optparse parser for the weighMultireads command line.

    The default of every option is read from the "weighMultireads" section of
    the configuration returned by getConfigParser(); the fallbacks used when
    the configuration has no entry are the ones spelled out below.
    """
    section = "weighMultireads"
    config = getConfigParser()
    defaults = {
        "radius": getConfigOption(config, section, "radius", None),
        "doRadius": getConfigBoolOption(config, section, "doRadius", True),
        "pairDist": getConfigOption(config, section, "pairDist", None),
        "verbose": getConfigBoolOption(config, section, "verbose", False),
        "cachePages": getConfigOption(config, section, "cachePages", None),
    }

    # (flag, add_option keyword arguments), in the order they are registered.
    optionTable = (
        ("--radius", {"type": "int", "dest": "radius"}),
        ("--noradius", {"action": "store_false", "dest": "doRadius"}),
        ("--usePairs", {"type": "int", "dest": "pairDist"}),
        ("--verbose", {"action": "store_true", "dest": "verbose"}),
        ("--cache", {"type": "int", "dest": "cachePages"}),
    )

    optionParser = optparse.OptionParser(usage=usage)
    for flag, settings in optionTable:
        optionParser.add_option(flag, **settings)

    optionParser.set_defaults(**defaults)

    return optionParser
62
63
def weighMultireads(rdsfile, radius=None, doRadius=True, pairDist=None, verbose=False, cachePages=None):
    """Reweigh the multireads stored in the dataset file rdsfile.

    rdsfile    -- path handed to ReadDataset.ReadDataset (always opened with
                  verbose=True).
    radius     -- radius for the unique-read pass.  Passing a value forces
                  that pass to run even if doRadius is False; None means a
                  radius of 100.
    doRadius   -- run the radius pass (see reweighUsingRadius).
    pairDist   -- when not None, first run the paired-read pass
                  (see reweighUsingPairs) with this distance.
    verbose    -- print timestamps along the way.
    cachePages -- when not None, open the dataset with caching, enlarge the
                  DB cache if this exceeds the dataset's default cache size,
                  and save the cache DB back to rdsfile at the end.
    """
    doCache = cachePages is not None
    if not doCache:
        cachePages = 1

    dataset = ReadDataset.ReadDataset(rdsfile, verbose=True, cache=doCache)
    if cachePages > dataset.getDefaultCacheSize():
        dataset.setDBcache(cachePages)

    if verbose:
        print(time.ctime())

    multireadIDs = dataset.getReadIDs(uniqs=False, multi=True)
    if verbose:
        print("%s %s" % ("got multiIDs ", time.ctime()))

    # Reads already settled by the pair pass are handed to the radius pass
    # as its list of reads to skip.
    if pairDist is None:
        alreadyFixed = []
    else:
        alreadyFixed = reweighUsingPairs(dataset, pairDist, multireadIDs, verbose)

    if radius is None:
        radius = 100
    else:
        doRadius = True

    if doRadius:
        reweighUsingRadius(dataset, radius, multireadIDs, alreadyFixed, verbose)

    if doCache:
        dataset.saveCacheDB(rdsfile)

    if verbose:
        print("%s %s" % ("finished", time.ctime()))
100
101
def reweighUsingPairs(RDS, pairDist, multiIDs, verbose=False):
    """Reweigh multireads whose mate is a unique (or spliced) read.

    For every fragment that has one unique end and one multiread end, each
    location of the multiread on the unique mate's chromosome is compared
    with the mate: a location whose start lies within pairDist of the mate's
    start or stop is a candidate.  The closest candidate receives the weight
    1 - (n - 1) / (100 * n) and each of the other n - 1 locations receives
    1 / (100 * n); the new weights are written with RDS.reweighMultireads.

    RDS      -- the read dataset to query and update.
    pairDist -- maximum distance between a multiread location and its mate.
    multiIDs -- IDs of all multireads in the dataset.
    verbose  -- print progress information.

    Returns the list of "readID/pairID" identifiers that were reweighed.
    """
    fixedPair = 0
    tooFar = pairDist * 10
    readlen = RDS.getReadSize()
    fixedReads = []
    print("doing pairs with pairDist = %d" % pairDist)
    hasSplices = RDS.dataType == "RNA"
    uniqIDs = RDS.getReadIDs(uniqs=True, multi=False, splices=hasSplices)

    if verbose:
        print("%s %s" % ("got uniqIDs ", time.ctime()))

    jointList, bothMultiList = getReadIDLists(uniqIDs, multiIDs, verbose)
    uniqDict = getUniqAndSpliceReadsFromReadIDs(RDS, jointList, verbose)
    if verbose:
        print("%s %s %s" % ("guDict actual ", len(uniqDict), time.ctime()))

    multiDict = getMultiReadsFromReadIDs(RDS, jointList, bothMultiList, verbose)
    if verbose:
        print("%s %s %s" % ("muDict actual ", len(multiDict), time.ctime()))

    RDS.setSynchronousPragma("OFF")
    for readID in jointList:
        uniqRead = uniqDict[readID]
        try:
            ustart = uniqRead["start"]
            ustop = ustart + readlen
        except KeyError:
            # No "start" field: presumably a spliced read, which carries
            # "startL"/"stopR" instead -- confirm against getSplicesDict.
            ustart = uniqRead["startL"]
            ustop = uniqRead["stopR"]

        uniqReadChrom = uniqRead["chrom"]
        multiReadList = multiDict[readID]
        numMultiReads = len(multiReadList)
        bestMatch = [tooFar] * numMultiReads
        found = False
        for index, multiRead in enumerate(multiReadList):
            mstart = multiRead["start"]
            multiReadChrom = multiRead["chrom"]
            # Deliberately read before the chromosome check: the pair ID of
            # the last location is the one used to build the read ID below.
            mpair = multiRead["pairID"]
            if uniqReadChrom != multiReadChrom:
                continue

            if abs(mstart - ustart) < pairDist:
                bestMatch[index] = abs(mstart - ustart)
                found = True
            elif abs(mstart - ustop) < pairDist:
                bestMatch[index] = abs(mstart - ustop)
                found = True

        if not found:
            continue

        # "found" means at least one distance is below pairDist, hence below
        # tooFar, so the minimum is a real candidate; index() picks the first
        # one on ties.  A best match at index 0 is as valid as any other: the
        # former "theMatch > 0" test silently dropped those reads.
        theMatch = bestMatch.index(min(bestMatch))

        theID = "/".join([readID, mpair])
        reweighList = []
        for index, multiRead in enumerate(multiReadList):
            if index == theMatch:
                score = 1 - ((numMultiReads - 1) / (100. * numMultiReads))
            else:
                score = 1 / (100. * numMultiReads)

            # Access the fields by name, as everywhere else in this function
            # (the former [0] / [1] lookups fail on dict records).
            start = multiRead["start"]
            chrom = "chr%s" % multiRead["chrom"]
            reweighList.append((round(score, 3), chrom, start, theID))

        RDS.reweighMultireads(reweighList)
        fixedPair += 1
        if verbose and fixedPair % 10000 == 1:
            print("fixed %d" % fixedPair)
            print(uniqDict[readID])
            print(multiDict[readID])
            print(reweighList)

        fixedReads.append(theID)

    RDS.setSynchronousPragma("ON")

    print("fixed %d pairs" % fixedPair)
    print(time.ctime())

    return fixedReads
189
190
def getReadIDLists(uniqIDs, multiIDs, verbose=False):
    """Classify paired fragments by which of their ends are multireads.

    uniqIDs  -- unique read IDs of the form "mainID/pairID".
    multiIDs -- multiread IDs of the form "mainID/pairID", optionally
                followed by "::multiplicity".
    verbose  -- print the size of every intermediate table.

    Returns (jointList, bothMultiList):
    jointList     -- mainIDs with exactly one multiread pairID whose mainID
                     also occurs among the unique IDs without having two
                     unique entries there.
    bothMultiList -- mainIDs with exactly two multiread pairIDs.

    Raises ValueError when the "mainID/pairID" part of an ID does not contain
    exactly one "/".
    """
    uidDict = {}
    for readID in uniqIDs:
        (mainID, pairID) = readID.split("/")
        uidDict.setdefault(mainID, []).append(pairID)

    if verbose:
        print("%s %s %s" % ("uidDict all ", len(uidDict), time.ctime()))

    # A fragment with two unique entries has nothing to rescue.  Build a new
    # dict rather than deleting while iterating over the keys.
    uidDict = dict((mainID, pairIDs) for (mainID, pairIDs) in uidDict.items() if len(pairIDs) != 2)

    if verbose:
        print("%s %s %s" % ("uidDict first candidates ", len(uidDict), time.ctime()))

    midDict = {}
    for readID in multiIDs:
        # The "::multiplicity" suffix is optional (reweighUsingRadius treats
        # it the same way), so do not insist on it being present.
        frontID = readID.split("::")[0]
        (mainID, pairID) = frontID.split("/")
        pairIDs = midDict.setdefault(mainID, [])
        if pairID not in pairIDs:
            pairIDs.append(pairID)

    if verbose:
        print("%s %s %s" % ("all multis ", len(midDict), time.ctime()))

    uidDict = dict((mainID, pairIDs) for (mainID, pairIDs) in uidDict.items() if mainID in midDict)

    if verbose:
        print("%s %s %s" % ("uidDict actual candidates ", len(uidDict), time.ctime()))

    jointList = []
    bothMultiList = []
    for mainID in midDict:
        listLen = len(midDict[mainID])
        if listLen == 1:
            if mainID in uidDict:
                jointList.append(mainID)
        elif listLen == 2:
            bothMultiList.append(mainID)

    if verbose:
        print("%s %s %s" % ("joint ", len(jointList), time.ctime()))
        print("%s %s %s" % ("bothMulti ", len(bothMultiList), time.ctime()))

    return jointList, bothMultiList
248
249
def getUniqAndSpliceReadsFromReadIDs(RDS, jointList, verbose=False):
    """Return {readID: first read} for the read IDs in jointList.

    The read is taken from the dataset's unique reads; when the ID is not
    found there and the dataset type is "RNA", it is taken from the splices
    instead.  IDs missing from the unique reads of a non-RNA dataset are
    left out of the result.
    """
    allUniqReads = RDS.getReadsDict(noSense=True, withChrom=True, withPairID=True, doUniqs=True, readIDDict=True)
    if verbose:
        print("%s %s %s" % ("got uniq dict ", len(allUniqReads), time.ctime()))

    if RDS.dataType == "RNA":
        allSpliceReads = RDS.getSplicesDict(noSense=True, withChrom=True, withPairID=True, readIDDict=True)
        if verbose:
            print("%s %s %s" % ("got splice dict ", len(allSpliceReads), time.ctime()))

    firstReads = {}
    for readID in jointList:
        try:
            firstReads[readID] = allUniqReads[readID][0]
        except KeyError:
            # Not a unique read: only RNA datasets have splices to fall back on.
            if RDS.dataType != "RNA":
                continue

            firstReads[readID] = allSpliceReads[readID][0]

    return firstReads
269
270
def getMultiReadsFromReadIDs(RDS, jointList, bothMultiList, verbose=False):
    """Return the multiread entries of the IDs in jointList and bothMultiList.

    The result maps each requested read ID to its entry in the dataset's
    multiread dictionary; an ID that is not in that dictionary raises
    KeyError.
    """
    allMultiReads = RDS.getReadsDict(noSense=True, withChrom=True, withPairID=True, doUniqs=False, doMulti=True, readIDDict=True)
    if verbose:
        print("%s %s %s" % ("got multi dict ", len(allMultiReads), time.ctime()))

    selected = {}
    for idList in (jointList, bothMultiList):
        for readID in idList:
            selected[readID] = allMultiReads[readID]

    return selected
284
285
def reweighUsingRadius(RDS, radius, multiIDs, readsToSkip=(), verbose=False):
    """Reweigh multireads according to the unique reads around each location.

    For every location of a multiread, the unique reads within radius of the
    middle of the read are counted (RDS.getCounts); the location's weight is
    (count + 1) divided by the sum of (count + 1) over all locations of that
    read, rounded to two decimals, and is written with RDS.reweighMultireads.

    RDS         -- the read dataset to query and update.
    radius      -- half-width of the window around the middle of the read.
    multiIDs    -- multiread IDs to process; a "::multiplicity" suffix is
                   stripped before the read is looked up.
    readsToSkip -- IDs to leave untouched, compared against the entries of
                   multiIDs exactly as given.
    verbose     -- also report how many reads were skipped.
    """
    # NOTE(review): the skip test uses the full ID, "::multiplicity" suffix
    # included, whereas reweighUsingPairs builds its IDs as "mainID/pairID"
    # without a suffix -- confirm that the two formats really agree.
    skipSet = frozenset(readsToSkip)  # constant-time membership tests
    skippedReads = 0
    readlen = RDS.getReadSize()
    halfreadlen = readlen // 2  # explicit floor division: same on Python 2 and 3
    print("doing uniq read radius with radius = %d" % radius)
    multiDict = RDS.getReadsDict(noSense=True, withWeight=True, withChrom=True, withID=True, doUniqs=False, doMulti=True, readIDDict=True)
    print("got multiDict")
    RDS.setSynchronousPragma("OFF")
    reweighedCount = 0
    for originalMultiReadID in multiIDs:
        if originalMultiReadID in skipSet:
            skippedReads += 1
            continue

        readID = originalMultiReadID.split("::")[0]

        # One (score, chromosome, start) entry per location of this read.
        locations = []
        for read in multiDict[readID]:
            start = read["start"]
            chromosome = "chr%s" % read["chrom"]
            regionStart = start + halfreadlen - radius
            regionStop = start + halfreadlen + radius
            uniqs = RDS.getCounts(chromosome, regionStart, regionStop, uniqs=True, multi=False, splices=False, reportCombined=True)
            # The +1 keeps every location at a non-zero weight.
            locations.append((uniqs + 1, chromosome, start))

        total = float(sum([score for (score, chromosome, start) in locations]))
        reweighList = [(round(score / total, 2), chromosome, start, originalMultiReadID) for (score, chromosome, start) in locations]

        RDS.reweighMultireads(reweighList)
        reweighedCount += 1
        if reweighedCount % 10000 == 0:
            print(reweighedCount)

    RDS.setSynchronousPragma("ON")
    if verbose:
        print("%s %s" % ("skipped ", skippedReads))

    print("%s %s" % ("reweighted ", reweighedCount))
330
331
# Run the command-line entry point only when the file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main(sys.argv)