snapshot of 4.0a development. initial git repo commit
[erange.git] / regionCounts.py
1 #
2 #  regionCounts.py
3 #  ENRAGE
4 #
5
# psyco is optional: if it cannot be imported or enabled, report that and keep
# running without it.
try:
    import psyco
    psyco.full()
except Exception:
    # Narrower than a bare except so KeyboardInterrupt / SystemExit still propagate.
    print('psyco not running')
11
12 import sys, string, optparse
13 from commoncode import readDataset, getMergedRegions, findPeak, writeLog
14
# Version banner; it is echoed at start-up and passed to writeLog() by regionCounts().
versionString = "%prog: version 3.9"
print(versionString)
17
def main(argv=None):
    """Parse the command line and run regionCounts().

    argv is the full argument vector including the program name (argv[0] is
    skipped when parsing); when it is empty or None, sys.argv is used.
    Three positional arguments are required: the region file, the RDS file
    and the output file name. With fewer, the usage line is printed and the
    process exits with status 1.
    """
    if not argv:
        argv = sys.argv

    usage = "usage: python %prog regionfile rdsfile outfilename [options]"

    parser = optparse.OptionParser(usage=usage)
    parser.add_option("--markRDS", action="store_true", dest="flagRDS")
    parser.add_option("--chromField", type="int", dest="cField")
    parser.add_option("--fullchrom", action="store_true", dest="useFullchrom")
    parser.add_option("--raw", action="store_false", dest="normalize")
    parser.add_option("--padregion", type="int", dest="padregion")
    parser.add_option("--mergeregion", type="int", dest="mergeregion")
    parser.add_option("--nomerge", action="store_false", dest="merging")
    parser.add_option("--noUniqs", action="store_false", dest="doUniqs")
    parser.add_option("--noMulti", action="store_false", dest="doMulti")
    parser.add_option("--splices", action="store_true", dest="doSplices")
    parser.add_option("--peak", action="store_true", dest="usePeak")
    parser.add_option("--cache", type="int", dest="cachePages")
    parser.add_option("--log", dest="logfilename")
    parser.add_option("--rpkm", action="store_true", dest="doRPKM")
    parser.add_option("--length", action="store_true", dest="doLength")
    parser.add_option("--force", action="store_true", dest="forceRegion")
    parser.set_defaults(flagRDS=False, cField=1, useFullchrom=False, normalize=True,
                        padregion=0, mergeregion=0, merging=True, doUniqs=True,
                        doMulti=True, doSplices=False, usePeak=False, cachePages=-1,
                        logfilename="regionCounts.log", doRPKM=False, doLength=False,
                        forceRegion=False)

    (options, args) = parser.parse_args(argv[1:])

    if len(args) < 3:
        # print_usage() expands %prog to the script name; printing the raw
        # usage string would show a literal "%prog" to the user.
        parser.print_usage()
        sys.exit(1)

    regionfilename = args[0]
    hitfile = args[1]
    outfilename = args[2]

    regionCounts(regionfilename, hitfile, outfilename, options.flagRDS, options.cField,
                 options.useFullchrom, options.normalize, options.padregion,
                 options.mergeregion, options.merging, options.doUniqs, options.doMulti,
                 options.doSplices, options.usePeak, options.cachePages, options.logfilename,
                 options.doRPKM, options.doLength, options.forceRegion)
62
63
def _fullChromName(chrom, useFullchrom):
    """Return the chromosome name to query: *chrom* as-is, or with "chr" prepended."""
    if useFullchrom:
        return chrom

    return "chr%s" % chrom


def _readsInRegion(reads, startIndex, rstart, rstop):
    """Collect the reads positioned in [rstart, rstop], scanning from startIndex.

    The scan stops at the first read positioned after rstop. Returns
    (readList, nextStartIndex), where nextStartIndex is startIndex advanced by
    the number of scanned reads positioned before rstart, so the next region
    can resume there. NOTE(review): the resume shortcut is only correct if the
    reads are sorted by position and regions are visited in increasing start
    order -- confirm against getReadsDict()/getMergedRegions().
    """
    readList = []
    nextStartIndex = startIndex
    index = startIndex
    numReads = len(reads)
    while index < numReads:
        read = reads[index]
        if read[0] < rstart:
            nextStartIndex += 1
        elif rstart <= read[0] <= rstop:
            readList.append(read)
        else:
            break

        index += 1

    return readList, nextStartIndex


def regionCounts(regionfilename, hitfile, outfilename, flagRDS=False, cField=1,
                 useFullchrom=False, normalize=True, padregion=0, mergeregion=0,
                 merging=True, doUniqs=True, doMulti=True, doSplices=False, usePeak=False,
                 cachePages=-1, logfilename="regionCounts.log", doRPKM=False, doLength=False,
                 forceRegion=False):
    """Count the reads of an RDS dataset falling in each region of a region file.

    Writes one tab-separated line per region to outfilename:
    label, chromosome, start, stop, count, and optionally a length column.

    Parameters:
        regionfilename: region file, read through getMergedRegions().
        hitfile: RDS dataset, opened with readDataset().
        outfilename: output file; overwritten.
        flagRDS: also flag the reads in each chromosome's regions via
            hitRDS.flagReads(). The synchronous pragma is OFF while doing so.
        cField: chromosome field, passed to getMergedRegions() as chromField.
        useFullchrom: use chromosome names verbatim instead of prefixing "chr".
        normalize: divide counts by the total number of reads in millions.
        padregion, mergeregion, merging: passed to getMergedRegions() as
            pad, maxDist and doMerge.
        doUniqs, doMulti, doSplices: read classes passed to getCounts() and
            flagReads(). doSplices also falls back to the splice-table
            chromosomes when the main table lists none.
        usePeak: instead of counting reads, use the smoothed profile value at
            the top position reported by findPeak() for the region's reads.
        cachePages: any value other than -1 enables the dataset cache. Values
            above the default cache size also resize the DB cache.
        logfilename: log file passed to writeLog().
        doRPKM: divide the normalized count by the region length in kb.
            Implies normalize.
        doLength: when normalizing, append the length divisor (kb under
            doRPKM, otherwise 1.0) as an extra column.
        forceRegion: also report regions on chromosomes absent from the
            dataset, with a count of zero, and sort the output by label.

    Returns None.
    """

    print("padding %d bp on each side of a region" % padregion)
    print("merging regions closer than %d bp" % mergeregion)
    if usePeak:
        print("will use peak values")

    doCache = cachePages != -1

    # RPKM is a per-million-reads measure, so it always requires normalization.
    # The caller's normalize/doRPKM choices are otherwise honored.
    if doRPKM:
        normalize = True

    writeLog(logfilename, versionString, " ".join(sys.argv[1:]))

    regionDict = getMergedRegions(regionfilename, maxDist=mergeregion, minHits=-1, keepLabel=True,
                                  fullChrom=useFullchrom, verbose=True, chromField=cField,
                                  doMerge=merging, pad=padregion)

    labelList = []
    labeltoRegionDict = {}
    regionCount = {}

    hitRDS = readDataset(hitfile, verbose=True, cache=doCache)
    readlen = hitRDS.getReadSize()
    if cachePages > hitRDS.getDefaultCacheSize():
        hitRDS.setDBcache(cachePages)

    totalCount = len(hitRDS)
    normalizationFactor = totalCount / 1000000.
    if normalizationFactor == 0:
        # Empty dataset: every count is zero, so avoid dividing by zero.
        normalizationFactor = 1.

    chromList = hitRDS.getChromosomes(fullChrom=useFullchrom)
    if len(chromList) == 0 and doSplices:
        chromList = hitRDS.getChromosomes(table="splices", fullChrom=useFullchrom)

    chromList.sort()

    if flagRDS:
        hitRDS.setSynchronousPragma("OFF")

    if forceRegion:
        # Regions on chromosomes the dataset lacks are still reported, with zero counts.
        for rchrom in regionDict:
            if rchrom in chromList:
                continue

            print(rchrom)
            for (label, start, stop, length) in regionDict[rchrom]:
                regionCount[label] = 0
                labelList.append(label)
                labeltoRegionDict[label] = (rchrom, start, stop)

    for rchrom in chromList:
        if rchrom not in regionDict:
            continue

        print(rchrom)
        fullchrom = _fullChromName(rchrom, useFullchrom)
        regions = regionDict[rchrom]

        if usePeak:
            # NOTE(review): multireads are always fetched here, regardless of doMulti.
            readDict = hitRDS.getReadsDict(chrom=fullchrom, withWeight=True, doMulti=True, findallOptimize=True)
            chromReads = readDict[fullchrom]
            rindex = 0

        for (label, start, stop, length) in regions:
            regionCount[label] = 0
            labelList.append(label)
            labeltoRegionDict[label] = (rchrom, start, stop)

        regionList = []
        for (label, rstart, rstop, length) in regions:
            regionList.append((label, fullchrom, rstart, rstop))
            if usePeak:
                readList, rindex = _readsInRegion(chromReads, rindex, rstart, rstop)
                if not readList:
                    continue

                readList.sort()
                (topPos, numHits, smoothArray, numPlus) = findPeak(readList, rstart, rstop - rstart, readlen, doWeight=True)
                try:
                    topValue = smoothArray[topPos[0]]
                except (IndexError, KeyError, TypeError):
                    print("problem with %s %s" % (str(topPos), str(smoothArray)))
                    continue

                regionCount[label] += topValue
            else:
                regionCount[label] += hitRDS.getCounts(fullchrom, rstart, rstop, uniqs=doUniqs, multi=doMulti, splices=doSplices)

        if flagRDS:
            hitRDS.flagReads(regionList, uniqs=doUniqs, multi=doMulti, splices=doSplices)

    if flagRDS:
        hitRDS.setSynchronousPragma("ON")

    if normalize:
        for label in regionCount:
            regionCount[label] = float(regionCount[label]) / normalizationFactor

    if forceRegion:
        labelList.sort()

    outfile = open(outfilename, "w")
    try:
        for label in labelList:
            (chrom, start, stop) = labeltoRegionDict[label]
            fullchrom = _fullChromName(chrom, useFullchrom)

            if normalize:
                if doRPKM:
                    length = abs(stop - start) / 1000.
                else:
                    length = 1.

                # Clamp very short regions so the per-kb division stays bounded.
                if length < 0.001:
                    length = 0.001

                outfile.write("%s\t%s\t%d\t%d\t%.2f" % (label, fullchrom, start, stop, regionCount[label]/length))
                if doLength:
                    outfile.write("\t%.1f" % length)
            else:
                outfile.write('%s\t%s\t%d\t%d\t%d' % (label, fullchrom, start, stop, regionCount[label]))

            outfile.write("\n")
    finally:
        outfile.close()

    if doCache and flagRDS:
        hitRDS.saveCacheDB(hitfile)

    writeLog(logfilename, versionString, "returned %d region counts for %s (%.2f M reads)" % (len(labelList), hitfile, totalCount / 1000000.))
218
219
if __name__ == "__main__":
    # Script entry point: main() falls back to sys.argv when given no argv.
    main()