snapshot of 4.0a development. initial git repo commit
[erange.git] / weighMultireads.py
1 #
2 #  weightMultireads.py
3 #  ENRAGE
4 #
5
6 #  Created by Ali Mortazavi on 10/02/08.
7 #
8
9 try:
10     import psyco
11     psyco.full()
12 except:
13     pass
14
15 from commoncode import readDataset
16 import sys, time, string, optparse
17
18 print "%prog: version 3.1"
19
20 def main(argv=None):
21     if not argv:
22         argv = sys.argv
23
24     usage = "usage: python %s rdsfile [--radius bp] [--noradius] [--usePairs maxDist] [--verbose] [--cache pages]"
25
26     parser = optparse.OptionParser(usage=usage)
27     parser.add_option("--radius", type="int", dest="radius")
28     parser.add_option("--noradius", action="store_false", dest="doRadius")
29     parser.add_option("--usePairs", type="int", dest="pairDist")
30     parser.add_option("--verbose", action="store_true", dest="verbose")
31     parser.add_option("--cache", type="int", dest="cachePages")
32     parser.set_defaults(radius=None, doRadius=True, pairDist=None, verbose=False, cachePages=None)
33     (options, args) = parser.parse_args(argv[1:])
34
35     if len(args) < 1:
36         print usage
37         sys.exit(1)
38
39     rdsfile = args[0]
40
41     weighMultireads(rdsfile, options.radius, options.doRadius, options.pairDist, options.verbose, options.cachePages)
42
43
44 def weighMultireads(rdsfile, radius=None, doRadius=True, pairDist=None, verbose=False, cachePages=None):
45
46     if radius is not None:
47         doRadius = True
48     else:
49         radius = 100
50
51     usePairs = False
52     if pairDist is not None:
53         usePairs = True
54
55     tooFar = pairDist * 10
56     
57     doCache = False
58     if cachePages is not None:
59         doCache = True
60     else:
61         cachePages = 1
62
63     RDS = readDataset(rdsfile, verbose = True, cache=doCache)
64     readlen = RDS.getReadSize()
65     halfreadlen = readlen / 2
66
67     if cachePages > RDS.getDefaultCacheSize():
68         RDS.setDBcache(cachePages)
69
70     if verbose:
71         print time.ctime()
72
73     multiIDs = RDS.getReadIDs(uniqs=False,multi=True)
74     if verbose:
75         print "got multiIDs ", time.ctime()
76
77     fixedPair = 0
78     fixedReads = []
79     if usePairs:
80         print "doing pairs with pairDist = %d" % pairDist
81         uidDict = {}
82         midDict = {}
83         jointList = []
84         bothMultiList = []
85         mainIDList = []
86         guDict = {}
87         muDict = {}
88
89         if RDS.dataType == "RNA":
90             uniqIDs = RDS.getReadIDs(uniqs=True,multi=False,splices=True)
91         else:
92             uniqIDs = RDS.getReadIDs(uniqs=True,multi=False,splices=False)
93
94         if verbose:
95             print "got uniqIDs ", time.ctime()
96
97         for readID in uniqIDs:
98             (mainID, pairID) = readID.split("/")
99             try:
100                 uidDict[mainID].append(pairID)
101             except:
102                 uidDict[mainID] = [pairID]
103                 mainIDList.append(mainID)
104
105         if verbose:
106             print "uidDict all ", len(uidDict), time.ctime()
107
108         for mainID in mainIDList:
109             if len(uidDict[mainID]) == 2:
110                 del uidDict[mainID]
111
112         if verbose:
113             print "uidDict first candidates ", len(uidDict), time.ctime()
114
115         for readID in multiIDs:
116             (frontID, multiplicity) = readID.split("::")
117             (mainID, pairID) = frontID.split("/")
118             try:
119                 if pairID not in midDict[mainID]:
120                     midDict[mainID].append(pairID)
121             except:
122                 midDict[mainID] = [pairID]
123
124         if verbose:
125             print "all multis ", len(midDict), time.ctime()
126
127         mainIDList = uidDict.keys()
128         for mainID in mainIDList:
129             if mainID not in midDict:
130                 del uidDict[mainID]
131
132         if verbose:
133             print "uidDict actual candidates ", len(uidDict), time.ctime()
134
135         for readID in midDict:
136             listLen = len(midDict[readID])
137             if listLen == 1:
138                 if readID in uidDict:
139                     jointList.append(readID)
140             elif listLen == 2:
141                 bothMultiList.append(readID)
142
143         if verbose:
144             print "joint ", len(jointList), time.ctime()
145             print "bothMulti ", len(bothMultiList), time.ctime()
146
147         del uidDict
148         del midDict
149         del mainIDList
150         del uniqIDs
151
152         uniqDict = RDS.getReadsDict(noSense=True, withChrom=True, withPairID=True, doUniqs=True, readIDDict=True)
153         if verbose:
154             print "got uniq dict ", len(uniqDict), time.ctime()
155
156         if RDS.dataType == "RNA":
157             spliceDict = RDS.getSplicesDict(noSense=True, withChrom=True, withPairID=True, readIDDict=True)
158             if verbose:
159                 print "got splice dict ", len(spliceDict), time.ctime()
160
161         for readID in jointList:
162             try:
163                 guDict[readID] = uniqDict[readID][0]
164             except:
165                 if RDS.dataType == "RNA":
166                     guDict[readID] = spliceDict[readID][0]
167
168         del uniqDict
169         del spliceDict
170         if verbose:
171             print "guDict actual ", len(guDict), time.ctime()
172
173         multiDict = RDS.getReadsDict(noSense=True, withChrom=True, withPairID=True, doUniqs=False, doMulti=True, readIDDict=True)
174         if verbose:
175             print "got multi dict ", len(multiDict), time.ctime()
176
177         for readID in jointList:
178             muDict[readID] = multiDict[readID]
179
180         for readID in bothMultiList:
181             muDict[readID] = multiDict[readID]
182
183         del multiDict
184         if verbose:
185             print "muDict actual ", len(muDict), time.ctime()
186
187         RDS.setSynchronousPragma("OFF")
188         for readID in jointList:
189             try:
190                 (ustart, uchrom, upair) = guDict[readID]
191                 ustop = ustart + readlen
192             except:
193                 (ustart, lstop, rstart, ustop, uchrom, upair) = guDict[readID]
194
195             muList = muDict[readID]
196             muLen = len(muList)
197             bestMatch = [tooFar] * muLen
198             found = False
199             for index in range(muLen):
200                 (mstart, mchrom, mpair) = muList[index]
201                 if uchrom != mchrom:
202                     continue
203
204                 if abs(mstart - ustart) < pairDist:
205                     bestMatch[index] = abs(mstart - ustart)
206                     found = True
207                 elif abs(mstart - ustop) < pairDist:
208                     bestMatch[index] = abs(mstart - ustop)
209                     found = True
210
211             if found:
212                 theMatch = -1
213                 theDist = tooFar
214                 reweighList = []
215                 for index in range(muLen):
216                     if theDist > bestMatch[index]:
217                         theMatch = index
218                         theDist = bestMatch[index]
219
220                 theID = string.join([readID, mpair], "/")
221                 for index in range(muLen):
222                     if index == theMatch:
223                         score = 1 - (muLen - 1) / (100. * (muLen))
224                     else:
225                         score = 1 / (100. * muLen)
226
227                     start = muList[index][0]
228                     chrom = "chr%s" % muList[index][1]
229                     reweighList.append((round(score,3), chrom, start, theID))
230
231                 if theMatch > 0:
232                     RDS.reweighMultireads(reweighList)
233                     fixedPair += 1
234                     if verbose and fixedPair % 10000 == 1:
235                         print "fixed %d" % fixedPair
236                         print guDict[readID]
237                         print muDict[readID]
238                         print reweighList
239
240                     fixedReads.append(theID)
241
242         RDS.setSynchronousPragma("ON")
243
244         del guDict
245         del muDict
246         print "fixed %d pairs" % fixedPair
247         print time.ctime()
248
249     skippedReads = 0
250     if doRadius:
251         print "doing uniq read radius with radius = %d" % radius
252         multiDict = RDS.getReadsDict(noSense=True, withWeight=True, withChrom=True, withID=True, doUniqs=False, doMulti=True, readIDDict=True)
253         print "got multiDict"
254         RDS.setSynchronousPragma("OFF")
255         rindex = 0
256         for readID in multiIDs:
257             theID = readID
258             if theID in fixedReads:
259                 skippedReads += 1
260                 continue
261
262             if "::" in readID:
263                 (readID, multiplicity) = readID.split("::")
264
265             scores = []
266             coords = []
267             for read in multiDict[readID]:
268                 (start, weight, rID, chrom) = read
269                 achrom = "chr%s" % chrom
270                 regionStart = start + halfreadlen - radius
271                 regionStop = start + halfreadlen + radius 
272                 uniqs = RDS.getCounts(achrom, regionStart, regionStop, uniqs=True, multi=False, splices=False, reportCombined=True)
273                 scores.append(uniqs + 1)
274                 coords.append((achrom, start, theID))
275
276             total = float(sum(scores))
277             reweighList = []
278             for index in range(len(scores)):
279                 reweighList.append((round(scores[index]/total,2), coords[index][0], coords[index][1], coords[index][2]))
280
281             RDS.reweighMultireads(reweighList)
282             rindex += 1
283             if rindex % 10000 == 0:
284                 print rindex
285
286         RDS.setSynchronousPragma("ON")
287         if verbose:
288             print "skipped ", skippedReads
289
290         print "reweighted ", rindex
291
292     if doCache:
293         RDS.saveCacheDB(rdsfile)
294
295     if verbose:
296         print "finished", time.ctime()
297     
298
299 if __name__ == "__main__":
300     main(sys.argv)