rewrite of findall.py and MakeRdsFromBam to fix bugs resulting from poor initial...
[erange.git] / regionCounts.py
1 #
2 #  regionCounts.py
3 #  ENRAGE
4 #
5
# Optionally enable psyco; if it cannot be imported or started, say so and continue.
try:
    import psyco
    psyco.full()
except Exception:
    # A bare `except:` would also swallow KeyboardInterrupt / SystemExit.
    print('psyco not running')
11
12 import sys
13 import string
14 import optparse
15 from commoncode import getMergedRegions, findPeak, writeLog, getConfigParser, getConfigOption, getConfigIntOption, getConfigBoolOption
16 import ReadDataset
17
# Version banner: echoed once when the module loads and passed to writeLog by regionCounts().
versionString = "regionCounts: version 3.10"
print(versionString)
20
def main(argv=None):
    """Command-line entry point for regionCounts.

    argv follows the sys.argv layout: argv[0] is the program name and is
    skipped. When argv is None or empty, sys.argv is used. Exits with
    status 1 when fewer than three positional arguments
    (regionfile, rdsfile, outfilename) are given.
    """
    if not argv:
        argv = sys.argv

    usage = "usage: python %prog regionfile rdsfile outfilename [options]"

    parser = getParser(usage)
    (options, args) = parser.parse_args(argv[1:])

    if len(args) < 3:
        # print_usage() expands %prog to the program name; printing the raw
        # usage string showed a literal "%prog" to the user.
        parser.print_usage()
        sys.exit(1)

    regionfilename = args[0]
    hitfile = args[1]
    outfilename = args[2]

    regionCounts(regionfilename, hitfile, outfilename, options.flagRDS, options.cField,
                 options.useFullchrom, options.normalize, options.padregion,
                 options.mergeregion, options.merging, options.doUniqs, options.doMulti,
                 options.doSplices, options.usePeak, options.cachePages, options.logfilename,
                 options.doRPKM, options.doLength, options.forceRegion)
43
44
def getParser(usage):
    """Return the optparse.OptionParser for the regionCounts command line.

    Option defaults come from the "regionCounts" section of the config
    returned by getConfigParser(). The last argument of each
    getConfig*Option call is the fallback value (presumably used when the
    option is missing from the config - confirm in commoncode).
    """
    parser = optparse.OptionParser(usage=usage)

    # (flag, add_option keyword arguments), in the order they are registered.
    optionTable = (
        ("--markRDS", dict(action="store_true", dest="flagRDS")),
        ("--chromField", dict(type="int", dest="cField")),
        ("--fullchrom", dict(action="store_true", dest="useFullchrom")),
        ("--raw", dict(action="store_false", dest="normalize")),
        ("--padregion", dict(type="int", dest="padregion")),
        ("--mergeregion", dict(type="int", dest="mergeregion")),
        ("--nomerge", dict(action="store_false", dest="merging")),
        ("--noUniqs", dict(action="store_false", dest="doUniqs")),
        ("--noMulti", dict(action="store_false", dest="doMulti")),
        ("--splices", dict(action="store_true", dest="doSplices")),
        ("--peak", dict(action="store_true", dest="usePeak")),
        ("--cache", dict(type="int", dest="cachePages")),
        ("--log", dict(dest="logfilename")),
        ("--rpkm", dict(action="store_true", dest="doRPKM")),
        ("--length", dict(action="store_true", dest="doLength")),
        ("--force", dict(action="store_true", dest="forceRegion")),
    )
    for flag, keywords in optionTable:
        parser.add_option(flag, **keywords)

    configParser = getConfigParser()
    section = "regionCounts"

    # (config reader, option name, fallback); each option name is also the parser dest.
    defaultTable = (
        (getConfigBoolOption, "flagRDS", False),
        (getConfigIntOption, "cField", 1),
        (getConfigBoolOption, "useFullchrom", False),
        (getConfigBoolOption, "normalize", True),
        (getConfigIntOption, "padregion", 0),
        (getConfigIntOption, "mergeregion", 0),
        (getConfigBoolOption, "merging", True),
        (getConfigBoolOption, "doUniqs", True),
        (getConfigBoolOption, "doMulti", True),
        (getConfigBoolOption, "doSplices", False),
        (getConfigBoolOption, "usePeak", False),
        (getConfigIntOption, "cachePages", -1),
        (getConfigOption, "logfilename", "regionCounts.log"),
        (getConfigBoolOption, "doRPKM", False),
        (getConfigBoolOption, "doLength", False),
        (getConfigBoolOption, "forceRegion", False),
    )
    defaults = {}
    for reader, name, fallback in defaultTable:
        defaults[name] = reader(configParser, section, name, fallback)

    parser.set_defaults(**defaults)
    return parser
90
91
def _fullChromName(chrom, useFullchrom):
    """Return chrom unchanged when useFullchrom is set, otherwise prefixed with "chr"."""
    if useFullchrom:
        return chrom
    return "chr%s" % chrom


def _readsInRegion(reads, rindex, start, stop):
    """Collect reads whose "start" lies in [start, stop], scanning reads from index rindex.

    Reads starting before `start` are skipped, and the returned index is
    advanced past them so the next region does not rescan them. The scan
    stops at the first read that starts after `stop`. The early stop and the
    index advance only give correct results if reads are ordered by start
    and regions are visited in increasing start order. Presumably
    getReadsDict(findallOptimize=True) and getMergedRegions guarantee this -
    TODO confirm.

    Returns (readList, newRindex).
    """
    readList = []
    localIndex = rindex
    while localIndex < len(reads):
        read = reads[localIndex]
        if read["start"] < start:
            rindex += 1
        elif read["start"] <= stop:
            readList.append(read)
        else:
            break
        localIndex += 1

    return readList, rindex


def _writeRegionCounts(outfilename, labelList, labeltoRegionDict, regionCount,
                       useFullchrom, normalize, doRPKM, doLength):
    """Write one tab-separated line per label in labelList to outfilename (overwriting it).

    Normalized lines are: label, chrom, start, stop, value/length (%.2f) and,
    if doLength, the length (%.1f). The length is |stop - start| / 1000 when
    doRPKM is set, otherwise 1.0, and never less than 0.001.
    Raw lines are: label, chrom, start, stop, count (%d).
    """
    with open(outfilename, "w") as outfile:
        for label in labelList:
            (chrom, start, stop) = labeltoRegionDict[label]
            fullchrom = _fullChromName(chrom, useFullchrom)

            if normalize:
                if doRPKM:
                    length = abs(stop - start) / 1000.
                else:
                    length = 1.

                # Keep zero-length regions from dividing by zero.
                if length < 0.001:
                    length = 0.001

                outfile.write("%s\t%s\t%d\t%d\t%.2f" % (label, fullchrom, start, stop, regionCount[label] / length))
                if doLength:
                    outfile.write("\t%.1f" % length)
            else:
                outfile.write('%s\t%s\t%d\t%d\t%d' % (label, fullchrom, start, stop, regionCount[label]))

            outfile.write("\n")


def regionCounts(regionfilename, hitfile, outfilename, flagRDS=False, cField=1,
                 useFullchrom=False, normalize=True, padregion=0, mergeregion=0,
                 merging=True, doUniqs=True, doMulti=True, doSplices=False, usePeak=False,
                 cachePages=-1, logfilename="regionCounts.log", doRPKM=False, doLength=False,
                 forceRegion=False):
    """Report, for every region in regionfilename, the reads in hitfile that fall inside it.

    Regions are loaded with getMergedRegions, which may pad and merge them.
    Each region on a chromosome present in the RDS gets a count from
    hitRDS.getCounts. With usePeak, it instead gets the top value of the
    smoothed peak that findPeak builds from the region's reads. Results are
    written to outfilename as tab-separated lines.

    Parameters:
        regionfilename -- region file passed to getMergedRegions
        hitfile        -- RDS file opened with ReadDataset.ReadDataset
        outfilename    -- output file (overwritten)
        flagRDS        -- also call hitRDS.flagReads on each chromosome's regions
        cField         -- passed to getMergedRegions as chromField
        useFullchrom   -- use chromosome names as-is instead of prefixing "chr"
        normalize      -- divide values by len(hitRDS) / 1e6
        padregion      -- passed to getMergedRegions as pad
        mergeregion    -- passed to getMergedRegions as maxDist
        merging        -- passed to getMergedRegions as doMerge
        doUniqs, doMulti, doSplices -- read classes passed to getCounts / flagReads
        usePeak        -- report the smoothed peak top instead of a read count
        cachePages     -- any value other than -1 enables the RDS cache; a value
                          above the default cache size also resizes it
        logfilename    -- log file for writeLog
        doRPKM         -- additionally divide by region length in kb (implies normalize)
        doLength       -- append the length column to normalized output lines
        forceRegion    -- also report regions on chromosomes absent from the RDS
                          (value 0), and sort output lines by label
    """
    print("padding %d bp on each side of a region" % padregion)
    print("merging regions closer than %d bp" % mergeregion)
    if usePeak:
        # Previously printed unconditionally, even in read-count mode.
        print("will use peak values")

    doCache = cachePages != -1

    # RPKM values are normalized by definition. The former unconditional
    # "normalize = True; doRPKM = False" here silently disabled --raw and --rpkm.
    if doRPKM:
        normalize = True

    writeLog(logfilename, versionString, " ".join(sys.argv[1:]))

    regionDict = getMergedRegions(regionfilename, maxDist=mergeregion, minHits=-1, keepLabel=True,
                                  fullChrom=useFullchrom, verbose=True, chromField=cField,
                                  doMerge=merging, pad=padregion)

    labelList = []
    labeltoRegionDict = {}
    regionCount = {}

    hitRDS = ReadDataset.ReadDataset(hitfile, verbose=True, cache=doCache)
    readlen = hitRDS.getReadSize()
    if cachePages > hitRDS.getDefaultCacheSize():
        hitRDS.setDBcache(cachePages)

    totalCount = len(hitRDS)
    normalizationFactor = totalCount / 1000000.

    chromList = hitRDS.getChromosomes(fullChrom=useFullchrom)
    if len(chromList) == 0 and doSplices:
        chromList = hitRDS.getChromosomes(table="splices", fullChrom=useFullchrom)

    chromList.sort()

    if flagRDS:
        hitRDS.setSynchronousPragma("OFF")

    # Regions on chromosomes the RDS does not know about are reported with a value of 0.
    if forceRegion:
        for rchrom in regionDict:
            if rchrom in chromList:
                continue

            print(rchrom)
            for region in regionDict[rchrom]:
                regionCount[region.label] = 0
                labelList.append(region.label)
                labeltoRegionDict[region.label] = (rchrom, region.start, region.stop)

    for rchrom in chromList:
        if rchrom not in regionDict:
            continue

        print(rchrom)
        fullchrom = _fullChromName(rchrom, useFullchrom)

        if usePeak:
            # NOTE(review): peak mode always fetches multireads (doMulti=True) and
            # ignores doUniqs/doMulti - confirm this is intended.
            readDict = hitRDS.getReadsDict(chrom=fullchrom, withWeight=True, doMulti=True, findallOptimize=True)
            # A chromosome with no reads (e.g. one that only has splices) has no entry.
            chromReads = readDict.get(fullchrom, [])
            rindex = 0

        regionList = []
        for region in regionDict[rchrom]:
            label = region.label
            start = region.start
            stop = region.stop
            regionCount[label] = 0
            labelList.append(label)
            labeltoRegionDict[label] = (rchrom, start, stop)
            regionList.append((label, fullchrom, start, stop))

            if usePeak:
                readList, rindex = _readsInRegion(chromReads, rindex, start, stop)
                if len(readList) < 1:
                    continue

                readList.sort()
                peak = findPeak(readList, start, stop - start, readlen, doWeight=True)
                try:
                    topValue = peak.smoothArray[peak.topPos[0]]
                except Exception:
                    print("problem with %s %s" % (str(peak.topPos), str(peak.smoothArray)))
                    continue

                regionCount[label] += topValue
            else:
                regionCount[label] += hitRDS.getCounts(fullchrom, start, stop, uniqs=doUniqs, multi=doMulti, splices=doSplices)

        if flagRDS:
            hitRDS.flagReads(regionList, uniqs=doUniqs, multi=doMulti, splices=doSplices)

    if flagRDS:
        hitRDS.setSynchronousPragma("ON")

    if normalize:
        if normalizationFactor > 0:
            for label in regionCount:
                regionCount[label] = float(regionCount[label]) / normalizationFactor
        else:
            # An empty dataset previously raised ZeroDivisionError here.
            print("no reads in %s - values are not normalized" % hitfile)

    if forceRegion:
        labelList.sort()

    _writeRegionCounts(outfilename, labelList, labeltoRegionDict, regionCount,
                       useFullchrom, normalize, doRPKM, doLength)

    if doCache and flagRDS:
        hitRDS.saveCacheDB(hitfile)

    writeLog(logfilename, versionString, "returned %d region counts for %s (%.2f M reads)" % (len(labelList), hitfile, totalCount / 1000000.))
242
243
if __name__ == "__main__":
    # main() falls back to sys.argv when it is called without an argument list.
    main()