first pass cleanup of cistematic/genomes; change bamPreprocessing
[erange.git] / regionCounts.py
index 0104cc2f4bf97bb64797a1b6ce418a8b7e1b4528..e11508ea65e7a61614afb971f88fdb8f00824da7 100755 (executable)
@@ -9,10 +9,13 @@ try:
 except:
     print 'psyco not running'
 
-import sys, string, optparse
-from commoncode import readDataset, getMergedRegions, findPeak, writeLog
+import sys
+import string
+import optparse
+import pysam
+from commoncode import getMergedRegions, findPeak, writeLog, getConfigParser, getConfigOption, getConfigIntOption, getConfigBoolOption, getHeaderComment, isSpliceEntry, getReadSense
 
-versionString = "%prog: version 3.9"
+versionString = "regionCounts: version 3.10"
 print versionString
 
 def main(argv=None):
@@ -21,6 +24,27 @@ def main(argv=None):
 
     usage = "usage: python %prog regionfile rdsfile outfilename [options]"
 
+    parser = getParser(usage)
+    (options, args) = parser.parse_args(argv[1:])
+
+    if len(args) < 3:
+        print usage
+        sys.exit(1)
+
+    regionfilename = args[0]
+    bamfilename =  args[1]
+    outfilename = args[2]
+
+    bamfile = pysam.Samfile(bamfilename, "rb")
+
+    regionCounts(regionfilename, bamfile, outfilename, options.flagRDS, options.cField,
+                 options.useFullchrom, options.normalize, options.padregion,
+                 options.mergeregion, options.merging, options.doUniqs, options.doMulti,
+                 options.doSplices, options.usePeak, options.cachePages, options.logfilename,
+                 options.doRPKM, options.doLength, options.forceRegion)
+
+
+def getParser(usage):
     parser = optparse.OptionParser(usage=usage)
     parser.add_option("--markRDS", action="store_true", dest="flagRDS")
     parser.add_option("--chromField", type="int", dest="cField")
@@ -38,30 +62,36 @@ def main(argv=None):
     parser.add_option("--rpkm", action="store_true", dest="doRPKM")
     parser.add_option("--length", action="store_true", dest="doLength")
     parser.add_option("--force", action="store_true", dest="forceRegion")
-    parser.set_defaults(flagRDS=False, cField=1, useFullchrom=False, normalize=True,
-                        padregion=0, mergeregion=0, merging=True, doUniqs=True,
-                        doMulti=True, doSplices=False, usePeak=False, cachePages=-1,
-                        logfilename="regionCounts.log", doRPKM=False, doLength=False,
-                        forceRegion=False)
-
-    (options, args) = parser.parse_args(argv[1:])
-
-    if len(args) < 3:
-        print usage
-        sys.exit(1)
-
-    regionfilename = args[0]
-    hitfile =  args[1]
-    outfilename = args[2]
-
-    regionCounts(regionfilename, hitfile, outfilename, options.flagRDS, options.cField,
-                 options.useFullchrom, options.normalize, options.padregion,
-                 options.mergeregion, options.merging, options.doUniqs, options.doMulti,
-                 options.doSplices, options.usePeak, options.cachePages, options.logfilename,
-                 options.doRPKM, options.doLength, options.forceRegion)
-
 
-def regionCounts(regionfilename, hitfile, outfilename, flagRDS=False, cField=1,
+    configParser = getConfigParser()
+    section = "regionCounts"
+    flagRDS = getConfigBoolOption(configParser, section, "flagRDS", False)
+    cField = getConfigIntOption(configParser, section, "cField", 1)
+    useFullchrom = getConfigBoolOption(configParser, section, "useFullchrom", False)
+    normalize = getConfigBoolOption(configParser, section, "normalize", True)
+    padregion = getConfigIntOption(configParser, section, "padregion", 0)
+    mergeregion = getConfigIntOption(configParser, section, "mergeregion", 0)
+    merging = getConfigBoolOption(configParser, section, "merging", True)
+    doUniqs = getConfigBoolOption(configParser, section, "doUniqs", True)
+    doMulti = getConfigBoolOption(configParser, section, "doMulti", True)
+    doSplices = getConfigBoolOption(configParser, section, "doSplices", False)
+    usePeak = getConfigBoolOption(configParser, section, "usePeak", False)
+    cachePages = getConfigIntOption(configParser, section, "cachePages", -1)
+    logfilename = getConfigOption(configParser, section, "logfilename", "regionCounts.log")
+    doRPKM = getConfigBoolOption(configParser, section, "doRPKM", False)
+    doLength = getConfigBoolOption(configParser, section, "doLength", False)
+    forceRegion = getConfigBoolOption(configParser, section, "forceRegion", False)
+
+    parser.set_defaults(flagRDS=flagRDS, cField=cField, useFullchrom=useFullchrom, normalize=normalize,
+                        padregion=padregion, mergeregion=mergeregion, merging=merging, doUniqs=doUniqs,
+                        doMulti=doMulti, doSplices=doSplices, usePeak=usePeak, cachePages=cachePages,
+                        logfilename=logfilename, doRPKM=doRPKM, doLength=doLength,
+                        forceRegion=forceRegion)
+
+    return parser
+
+
+def regionCounts(regionfilename, bamfile, outfilename, flagRDS=False, cField=1,
                  useFullchrom=False, normalize=True, padregion=0, mergeregion=0,
                  merging=True, doUniqs=True, doMulti=True, doSplices=False, usePeak=False,
                  cachePages=-1, logfilename="regionCounts.log", doRPKM=False, doLength=False,
@@ -71,11 +101,6 @@ def regionCounts(regionfilename, hitfile, outfilename, flagRDS=False, cField=1,
     print "merging regions closer than %d bp" % mergeregion
     print "will use peak values"
 
-    if cachePages != -1:
-        doCache = True
-    else:
-        doCache = False
-
     normalize = True
     doRPKM = False
     if doRPKM == True:
@@ -91,34 +116,23 @@ def regionCounts(regionfilename, hitfile, outfilename, flagRDS=False, cField=1,
     labeltoRegionDict = {}
     regionCount = {}
 
-    hitRDS = readDataset(hitfile, verbose=True, cache=doCache)
-    readlen = hitRDS.getReadSize()
-    if cachePages > hitRDS.getDefaultCacheSize():
-        hitRDS.setDBcache(cachePages)
-
-    totalCount = len(hitRDS)
+    readlen = getHeaderComment(bamfile.header, "ReadLength")
+    totalCount = getHeaderComment(bamfile.header, "Total")
     if normalize:
         normalizationFactor = totalCount / 1000000.
 
-    chromList = hitRDS.getChromosomes(fullChrom=useFullchrom)
-    if len(chromList) == 0 and doSplices:
-        chromList = hitRDS.getChromosomes(table="splices", fullChrom=useFullchrom)
-
+    chromList = [chrom for chrom in bamfile.references if chrom != "chrM"]
     chromList.sort()
-
-    if flagRDS:
-        hitRDS.setSynchronousPragma("OFF")        
-
     for rchrom in regionDict:
         if forceRegion and rchrom not in chromList:
             print rchrom
-            for (label, start, stop, length) in regionDict[rchrom]:
-                regionCount[label] = 0
-                labelList.append(label)
-                labeltoRegionDict[label] = (rchrom, start, stop)
+            for region in regionDict[rchrom]:
+                regionCount[region.label] = 0
+                labelList.append(region.label)
+                labeltoRegionDict[region.label] = (rchrom, region.start, region.stop)
 
     for rchrom in chromList:
-        regionList = []
+        #regionList = []
         if rchrom not in regionDict:
             continue
 
@@ -128,54 +142,34 @@ def regionCounts(regionfilename, hitfile, outfilename, flagRDS=False, cField=1,
         else:
             fullchrom = "chr%s" % rchrom
 
-        if usePeak:
-            readDict = hitRDS.getReadsDict(chrom=fullchrom, withWeight=True, doMulti=True, findallOptimize=True)
-            rindex = 0
-            dictLen = len(readDict[fullchrom])
-
-        for (label, start, stop, length) in regionDict[rchrom]:
+        for region in regionDict[rchrom]:
+            label = region.label
+            start = region.start
+            stop = region.stop
             regionCount[label] = 0
             labelList.append(label)
             labeltoRegionDict[label] = (rchrom, start, stop)
-
-        if useFullchrom:
-            fullchrom = rchrom
-        else:
-            fullchrom = "chr%s" % rchrom
-
-        for (label, rstart, rstop, length) in regionDict[rchrom]:
-            regionList.append((label, fullchrom, rstart, rstop))
+            #regionList.append((label, fullchrom, start, stop))
             if usePeak:
                 readList = []
-                for localIndex in xrange(rindex, dictLen):
-                    read = readDict[fullchrom][localIndex]
-                    if read[0] < rstart:
-                        rindex += 1
-                    elif rstart <= read[0] <= rstop:
-                        readList.append(read)
-                    else:
-                        break
+                for alignedread in bamfile.fetch(fullchrom, start, stop):
+                    weight = 1.0/alignedread.opt('NH')
+                    readList.append({"start": alignedread.pos, "sense": getReadSense(alignedread), "weight": weight})
 
                 if len(readList) < 1:
                     continue
 
                 readList.sort()
-                (topPos, numHits, smoothArray, numPlus) = findPeak(readList, rstart, rstop - rstart, readlen, doWeight=True)
+                peak = findPeak(readList, start, stop - start, readlen, doWeight=True)
                 try:
-                    topValue = smoothArray[topPos[0]]
+                    topValue = peak.smoothArray[peak.topPos[0]]
                 except:
-                    print "problem with %s %s" % (str(topPos), str(smoothArray))
+                    print "problem with %s %s" % (str(peak.topPos), str(peak.smoothArray))
                     continue
 
                 regionCount[label] += topValue
             else:
-                regionCount[label] += hitRDS.getCounts(fullchrom, rstart, rstop, uniqs=doUniqs, multi=doMulti, splices=doSplices)
-
-        if flagRDS:
-            hitRDS.flagReads(regionList, uniqs=doUniqs, multi=doMulti, splices=doSplices)
-
-    if flagRDS:
-        hitRDS.setSynchronousPragma("ON")    
+                regionCount[label] += getRegionReadCounts(bamfile, fullchrom, start, stop, doUniqs=doUniqs, doMulti=doMulti, doSplices=doSplices)
 
     if normalize:
         for label in regionCount:
@@ -211,10 +205,29 @@ def regionCounts(regionfilename, hitfile, outfilename, flagRDS=False, cField=1,
         outfile.write("\n")
 
     outfile.close()
-    if doCache and flagRDS:
-        hitRDS.saveCacheDB(hitfile)
-
-    writeLog(logfilename, versionString, "returned %d region counts for %s (%.2f M reads)" % (len(labelList), hitfile, totalCount / 1000000.))
+    writeLog(logfilename, versionString, "returned %d region counts (%.2f M reads)" % (len(labelList), totalCount / 1000000.))
+
+
+def getRegionReadCounts(bamfile, chr, start, end, doUniqs=True, doMulti=False, doSplices=False):
+    uniques = 0
+    multis = 0.0
+    uniqueSplice = 0
+    multiSplice = 0.0
+    for alignedread in bamfile.fetch(chr, start, end):
+        readMultiplicity = alignedread.opt('NH')
+        if doSplices and isSpliceEntry(alignedread.cigar):
+            if readMultiplicity == 1 and doUniqs:
+                uniqueSplice += 1
+            elif doMulti:
+                multiSplice += 1.0/readMultiplicity
+        elif readMultiplicity == 1 and doUniqs:
+            uniques += 1
+        elif doMulti:
+            multis += 1.0/readMultiplicity
+
+    totalReads = uniques + multis + uniqueSplice + multiSplice
+
+    return totalReads
 
 
 if __name__ == "__main__":