convert standard analysis pipelines to use bam format natively
[erange.git] / regionCounts.py
index ae005cb3146b686fd3934cd9a152c331900ebde4..e11508ea65e7a61614afb971f88fdb8f00824da7 100755 (executable)
@@ -12,8 +12,8 @@ except:
 import sys
 import string
 import optparse
-from commoncode import getMergedRegions, findPeak, writeLog, getConfigParser, getConfigOption, getConfigIntOption, getConfigBoolOption
-import ReadDataset
+import pysam
+from commoncode import getMergedRegions, findPeak, writeLog, getConfigParser, getConfigOption, getConfigIntOption, getConfigBoolOption, getHeaderComment, isSpliceEntry, getReadSense
 
 versionString = "regionCounts: version 3.10"
 print versionString
@@ -32,10 +32,12 @@ def main(argv=None):
         sys.exit(1)
 
     regionfilename = args[0]
-    hitfile =  args[1]
+    bamfilename =  args[1]
     outfilename = args[2]
 
-    regionCounts(regionfilename, hitfile, outfilename, options.flagRDS, options.cField,
+    bamfile = pysam.Samfile(bamfilename, "rb")
+
+    regionCounts(regionfilename, bamfile, outfilename, options.flagRDS, options.cField,
                  options.useFullchrom, options.normalize, options.padregion,
                  options.mergeregion, options.merging, options.doUniqs, options.doMulti,
                  options.doSplices, options.usePeak, options.cachePages, options.logfilename,
@@ -89,7 +91,7 @@ def getParser(usage):
     return parser
 
 
-def regionCounts(regionfilename, hitfile, outfilename, flagRDS=False, cField=1,
+def regionCounts(regionfilename, bamfile, outfilename, flagRDS=False, cField=1,
                  useFullchrom=False, normalize=True, padregion=0, mergeregion=0,
                  merging=True, doUniqs=True, doMulti=True, doSplices=False, usePeak=False,
                  cachePages=-1, logfilename="regionCounts.log", doRPKM=False, doLength=False,
@@ -99,11 +101,6 @@ def regionCounts(regionfilename, hitfile, outfilename, flagRDS=False, cField=1,
     print "merging regions closer than %d bp" % mergeregion
     print "will use peak values"
 
-    if cachePages != -1:
-        doCache = True
-    else:
-        doCache = False
-
     normalize = True
     doRPKM = False
     if doRPKM == True:
@@ -119,24 +116,13 @@ def regionCounts(regionfilename, hitfile, outfilename, flagRDS=False, cField=1,
     labeltoRegionDict = {}
     regionCount = {}
 
-    hitRDS = ReadDataset.ReadDataset(hitfile, verbose=True, cache=doCache)
-    readlen = hitRDS.getReadSize()
-    if cachePages > hitRDS.getDefaultCacheSize():
-        hitRDS.setDBcache(cachePages)
-
-    totalCount = len(hitRDS)
+    readlen = int(getHeaderComment(bamfile.header, "ReadLength"))
+    totalCount = int(getHeaderComment(bamfile.header, "Total"))
     if normalize:
         normalizationFactor = totalCount / 1000000.
 
-    chromList = hitRDS.getChromosomes(fullChrom=useFullchrom)
-    if len(chromList) == 0 and doSplices:
-        chromList = hitRDS.getChromosomes(table="splices", fullChrom=useFullchrom)
-
+    chromList = [chrom for chrom in bamfile.references if chrom != "chrM"]
     chromList.sort()
-
-    if flagRDS:
-        hitRDS.setSynchronousPragma("OFF")        
-
     for rchrom in regionDict:
         if forceRegion and rchrom not in chromList:
             print rchrom
@@ -146,7 +132,7 @@ def regionCounts(regionfilename, hitfile, outfilename, flagRDS=False, cField=1,
                 labeltoRegionDict[region.label] = (rchrom, region.start, region.stop)
 
     for rchrom in chromList:
-        regionList = []
+        # regionList bookkeeping dropped along with the removed RDS flagReads path
         if rchrom not in regionDict:
             continue
 
@@ -156,11 +142,6 @@ def regionCounts(regionfilename, hitfile, outfilename, flagRDS=False, cField=1,
         else:
             fullchrom = "chr%s" % rchrom
 
-        if usePeak:
-            readDict = hitRDS.getReadsDict(chrom=fullchrom, withWeight=True, doMulti=True, findallOptimize=True)
-            rindex = 0
-            dictLen = len(readDict[fullchrom])
-
         for region in regionDict[rchrom]:
             label = region.label
             start = region.start
@@ -168,17 +149,12 @@ def regionCounts(regionfilename, hitfile, outfilename, flagRDS=False, cField=1,
             regionCount[label] = 0
             labelList.append(label)
             labeltoRegionDict[label] = (rchrom, start, stop)
-            regionList.append((label, fullchrom, start, stop))
+            # (label, fullchrom, start, stop) no longer collected; RDS flagReads path removed
             if usePeak:
                 readList = []
-                for localIndex in xrange(rindex, dictLen):
-                    read = readDict[fullchrom][localIndex]
-                    if read["start"] < start:
-                        rindex += 1
-                    elif start <= read["start"] <= stop:
-                        readList.append(read)
-                    else:
-                        break
+                for alignedread in bamfile.fetch(fullchrom, start, stop):
+                    weight = 1.0/alignedread.opt('NH')  # NOTE(review): raises KeyError if NH tag absent -- confirm aligner always emits it
+                    readList.append({"start": alignedread.pos, "sense": getReadSense(alignedread), "weight": weight})
 
                 if len(readList) < 1:
                     continue
@@ -193,13 +169,7 @@ def regionCounts(regionfilename, hitfile, outfilename, flagRDS=False, cField=1,
 
                 regionCount[label] += topValue
             else:
-                regionCount[label] += hitRDS.getCounts(fullchrom, start, stop, uniqs=doUniqs, multi=doMulti, splices=doSplices)
-
-        if flagRDS:
-            hitRDS.flagReads(regionList, uniqs=doUniqs, multi=doMulti, splices=doSplices)
-
-    if flagRDS:
-        hitRDS.setSynchronousPragma("ON")    
+                regionCount[label] += getRegionReadCounts(bamfile, fullchrom, start, stop, doUniqs=doUniqs, doMulti=doMulti, doSplices=doSplices)
 
     if normalize:
         for label in regionCount:
@@ -235,10 +205,29 @@ def regionCounts(regionfilename, hitfile, outfilename, flagRDS=False, cField=1,
         outfile.write("\n")
 
     outfile.close()
-    if doCache and flagRDS:
-        hitRDS.saveCacheDB(hitfile)
-
-    writeLog(logfilename, versionString, "returned %d region counts for %s (%.2f M reads)" % (len(labelList), hitfile, totalCount / 1000000.))
+    writeLog(logfilename, versionString, "returned %d region counts (%.2f M reads)" % (len(labelList), totalCount / 1000000.))
+
+
+def getRegionReadCounts(bamfile, chrom, start, end, doUniqs=True, doMulti=False, doSplices=False):
+    uniques = 0
+    multis = 0.0
+    uniqueSplice = 0
+    multiSplice = 0.0
+    for alignedread in bamfile.fetch(chrom, start, end):
+        readMultiplicity = alignedread.opt('NH')
+        if doSplices and isSpliceEntry(alignedread.cigar):
+            if readMultiplicity == 1 and doUniqs:
+                uniqueSplice += 1
+            elif doMulti:
+                multiSplice += 1.0/readMultiplicity
+        elif readMultiplicity == 1 and doUniqs:
+            uniques += 1
+        elif doMulti:
+            multis += 1.0/readMultiplicity
+
+    totalReads = uniques + multis + uniqueSplice + multiSplice
+
+    return totalReads
 
 
 if __name__ == "__main__":