convert standard analysis pipelines to use bam format natively
[erange.git] / findall.py
index 9dd85df733f86417debf9b060e53f59230110b76..f8907a66c2db6d3b36a4ac9a13872554c1f1d9ca 100755 (executable)
@@ -1,6 +1,6 @@
 """
-    usage: python $ERANGEPATH/findall.py label samplerdsfile regionoutfile
-           [--control controlrdsfile] [--minimum minHits] [--ratio minRatio]
+    usage: python $ERANGEPATH/findall.py label samplebamfile regionoutfile
+           [--control controlbamfile] [--minimum minHits] [--ratio minRatio]
            [--spacing maxSpacing] [--listPeak] [--shift #bp | learn] [--learnFold num]
            [--noshift] [--autoshift] [--reportshift] [--nomulti] [--minPlus fraction]
            [--maxPlus fraction] [--leftPlus fraction] [--minPeak RPM] [--raw]
@@ -50,8 +50,8 @@ import math
 import string
 import optparse
 import operator
-from commoncode import writeLog, findPeak, getBestShiftForRegion, getConfigParser, getConfigOption, getConfigIntOption, getConfigFloatOption, getConfigBoolOption
-import ReadDataset
+import pysam
+from commoncode import writeLog, findPeak, getConfigParser, getConfigOption, getConfigIntOption, getConfigFloatOption, getConfigBoolOption, isSpliceEntry
 import Region
 
 
@@ -65,7 +65,7 @@ class RegionDirectionError(Exception):
 class RegionFinder():
     def __init__(self, label, minRatio=4.0, minPeak=0.5, minPlusRatio=0.25, maxPlusRatio=0.75, leftPlusRatio=0.3, strandfilter="",
                  minHits=4.0, trimValue=0.1, doTrim=True, doDirectionality=True, shiftValue=0, maxSpacing=50, withFlag="",
-                 normalize=True, listPeak=False, reportshift=False, stringency=1.0):
+                 normalize=True, listPeak=False, reportshift=False, stringency=1.0, controlfile=None, doRevBackground=False):
 
         self.statistics = {"index": 0,
                            "total": 0,
@@ -77,8 +77,8 @@ class RegionFinder():
 
         self.regionLabel = label
         self.rnaSettings = False
-        self.controlRDSsize = 1
-        self.sampleRDSsize = 1
+        self.controlRDSsize = 1.0
+        self.sampleRDSsize = 1.0
         self.minRatio = minRatio
         self.minPeak = minPeak
         self.leftPlusRatio = leftPlusRatio
@@ -115,6 +115,10 @@ class RegionFinder():
         self.listPeak = listPeak
         self.reportshift = reportshift
         self.stringency = max(stringency, 1.0)
+        self.controlfile = controlfile
+        self.doControl = self.controlfile is not None
+        self.doPvalue = False
+        self.doRevBackground = doRevBackground
 
 
     def useRNASettings(self, readlen):
@@ -125,7 +129,7 @@ class RegionFinder():
         self.maxSpacing = readlen
 
 
-    def getHeader(self, doPvalue):
+    def getHeader(self):
         if self.normalize:
             countType = "RPM"
         else:
@@ -142,33 +146,33 @@ class RegionFinder():
         if self.reportshift:
             headerFields.append("readShift")
 
-        if doPvalue:
+        if self.doPvalue:
             headerFields.append("pValue")
 
         return string.join(headerFields, "\t")
 
 
-    def printSettings(self, doRevBackground, ptype, doControl, useMulti, doCache, pValueType):
+    def printSettings(self, ptype, useMulti, pValueType):
         print
-        self.printStatusMessages(doRevBackground, ptype, doControl, useMulti)
-        self.printOptionsSummary(useMulti, doCache, pValueType)
+        self.printStatusMessages(ptype, useMulti)
+        self.printOptionsSummary(useMulti, pValueType)
 
 
-    def printStatusMessages(self, doRevBackground, ptype, doControl, useMulti):
+    def printStatusMessages(self, ptype, useMulti):
         if self.shiftValue == "learn":
             print "Will try to learn shift"
 
         if self.normalize:
             print "Normalizing to RPM"
 
-        if doRevBackground:
+        if self.doRevBackground:
             print "Swapping IP and background to calculate FDR"
 
         if ptype != "":
             if ptype in ["NONE", "SELF"]:
                 pass
             elif ptype == "BACK":
-                if doControl and doRevBackground:
+                if self.doControl and self.doRevBackground:
                     pass
                 else:
                     print "must have a control dataset and -revbackground for pValue type 'back'"
@@ -190,9 +194,9 @@ class RegionFinder():
             print "only analyzing reads on the minus strand"
 
 
-    def printOptionsSummary(self, useMulti, doCache, pValueType):
+    def printOptionsSummary(self, useMulti, pValueType):
 
-        print "\nenforceDirectionality=%s listPeak=%s nomulti=%s cache=%s " % (self.doDirectionality, self.listPeak, not useMulti, doCache)
+        print "\nenforceDirectionality=%s listPeak=%s nomulti=%s " % (self.doDirectionality, self.listPeak, not useMulti)
         print "spacing<%d minimum>%.1f ratio>%.1f minPeak=%.1f\ttrimmed=%s\tstrand=%s" % (self.maxSpacing, self.minHits, self.minRatio, self.minPeak, self.trimString, self.stranded)
         try:
             print "minPlus=%.2f maxPlus=%.2f leftPlus=%.2f shift=%d pvalue=%s" % (self.minPlusRatio, self.maxPlusRatio, self.leftPlusRatio, self.shiftValue, pValueType)
@@ -200,18 +204,18 @@ class RegionFinder():
             print "minPlus=%.2f maxPlus=%.2f leftPlus=%.2f shift=%s pvalue=%s" % (self.minPlusRatio, self.maxPlusRatio, self.leftPlusRatio, self.shiftValue, pValueType)
 
 
-    def getAnalysisDescription(self, hitfile, useMulti, doCache, pValueType, controlfile, doControl):
+    def getAnalysisDescription(self, hitfile, useMulti, pValueType):
 
         description = ["#ERANGE %s" % versionString]
-        if doControl:
-            description.append("#enriched sample:\t%s (%.1f M reads)\n#control sample:\t%s (%.1f M reads)" % (hitfile, self.sampleRDSsize, controlfile, self.controlRDSsize))
+        if self.doControl:
+            description.append("#enriched sample:\t%s (%.1f M reads)\n#control sample:\t%s (%.1f M reads)" % (hitfile, self.sampleRDSsize, self.controlfile, self.controlRDSsize))
         else:
             description.append("#enriched sample:\t%s (%.1f M reads)\n#control sample: none" % (hitfile, self.sampleRDSsize))
 
         if self.withFlag != "":
             description.append("#restrict to Flag = %s" % self.withFlag)
 
-        description.append("#enforceDirectionality=%s listPeak=%s nomulti=%s cache=%s" % (self.doDirectionality, self.listPeak, not useMulti, doCache))
+        description.append("#enforceDirectionality=%s listPeak=%s nomulti=%s" % (self.doDirectionality, self.listPeak, not useMulti))
         description.append("#spacing<%d minimum>%.1f ratio>%.1f minPeak=%.1f trimmed=%s strand=%s" % (self.maxSpacing, self.minHits, self.minRatio, self.minPeak, self.trimString, self.stranded))
         try:
             description.append("#minPlus=%.2f maxPlus=%.2f leftPlus=%.2f shift=%d pvalue=%s" % (self.minPlusRatio, self.maxPlusRatio, self.leftPlusRatio, self.shiftValue, pValueType))
@@ -221,6 +225,31 @@ class RegionFinder():
         return string.join(description, "\n")
 
 
+    def getFooter(self, bestShift):
+        index = self.statistics["index"]
+        mIndex = self.statistics["mIndex"]
+        footerLines = ["#stats:\t%.1f RPM in %d regions" % (self.statistics["total"], index)]
+        if self.doDirectionality:
+            footerLines.append("#\t\t%d additional regions failed directionality filter" % self.statistics["failed"])
+
+        if self.doRevBackground:
+            try:
+                percent = min(100. * (float(mIndex)/index), 100.)
+            except ZeroDivisionError:
+                percent = 0.
+
+            footerLines.append("#%d regions (%.1f RPM) found in background (FDR = %.2f percent)" % (mIndex, self.statistics["mTotal"], percent))
+
+        if self.shiftValue == "auto" and self.reportshift:
+            
+            footerLines.append("#mode of shift values: %d" % bestShift)
+
+        if self.statistics["badRegionTrim"] > 0:
+            footerLines.append("#%d regions discarded due to trimming problems" % self.statistics["badRegionTrim"])
+
+        return string.join(footerLines, "\n")
+
+
     def updateControlStatistics(self, peak, sumAll, peakScore):
 
         plusRatio = float(peak.numPlus)/peak.numHits
@@ -281,11 +310,11 @@ def main(argv=None):
                                 minHits=options.minHits, trimValue=options.trimValue, doTrim=options.doTrim,
                                 doDirectionality=options.doDirectionality, shiftValue=shiftValue, maxSpacing=options.maxSpacing,
                                 withFlag=options.withFlag, normalize=options.normalize, listPeak=options.listPeak,
-                                reportshift=options.reportshift, stringency=options.stringency)
+                                reportshift=options.reportshift, stringency=options.stringency, controlfile=options.controlfile,
+                                doRevBackground=options.doRevBackground)
 
     findall(regionFinder, hitfile, outfilename, options.logfilename, outputMode, options.rnaSettings,
-            options.cachePages, options.ptype, options.controlfile, options.doRevBackground,
-            options.useMulti, options.combine5p)
+            options.ptype, options.useMulti, options.combine5p)
 
 
 def makeParser():
@@ -363,51 +392,67 @@ def makeParser():
     return parser
 
 
-def findall(regionFinder, hitfile, outfilename, logfilename="findall.log", outputMode="w", rnaSettings=False, cachePages=None,
-            ptype="", controlfile=None, doRevBackground=False, useMulti=True, combine5p=False):
+def findall(regionFinder, hitfile, outfilename, logfilename="findall.log", outputMode="w", rnaSettings=False,
+            ptype="", useMulti=True, combine5p=False):
 
     writeLog(logfilename, versionString, string.join(sys.argv[1:]))
-    doCache = cachePages is not None
-    controlRDS = None
-    doControl = controlfile is not None
-    if doControl:
+    controlBAM = None
+    if regionFinder.doControl:
         print "\ncontrol:" 
-        controlRDS = openRDSFile(controlfile, cachePages=cachePages, doCache=doCache)
-        regionFinder.controlRDSsize = len(controlRDS) / 1000000.
+        controlBAM = pysam.Samfile(regionFinder.controlfile, "rb")
+        regionFinder.controlRDSsize = int(getHeaderComment(controlBAM.header, "Total")) / 1000000.
 
     print "\nsample:" 
-    hitRDS = openRDSFile(hitfile, cachePages=cachePages, doCache=doCache)
-    regionFinder.sampleRDSsize = len(hitRDS) / 1000000.
-    pValueType = getPValueType(ptype, doControl, doRevBackground)
-    doPvalue = not pValueType == "none"
-    regionFinder.readlen = hitRDS.getReadSize()
+    sampleBAM = pysam.Samfile(hitfile, "rb")
+    regionFinder.sampleRDSsize = int(getHeaderComment(sampleBAM.header, "Total")) / 1000000.
+    pValueType = getPValueType(ptype, regionFinder.doControl, regionFinder.doRevBackground)
+    regionFinder.doPvalue = not pValueType == "none"
+    regionFinder.readlen = int(getHeaderComment(sampleBAM.header, "ReadLength"))
     if rnaSettings:
         regionFinder.useRNASettings(regionFinder.readlen)
 
-    regionFinder.printSettings(doRevBackground, ptype, doControl, useMulti, doCache, pValueType)
+    regionFinder.printSettings(ptype, useMulti, pValueType)
     outfile = open(outfilename, outputMode)
-    header = writeOutputFileHeader(regionFinder, outfile, hitfile, useMulti, doCache, pValueType, doPvalue, controlfile, doControl)
+    header = writeOutputFileHeader(regionFinder, outfile, hitfile, useMulti, pValueType)
     shiftDict = {}
-    chromosomeList = getChromosomeListToProcess(hitRDS, controlRDS, doControl)
-    for chromosome in chromosomeList:
-        #TODO: QForAli -Really? Use first chr shift value for all of them
+    chromList = getChromosomeListToProcess(sampleBAM, controlBAM)
+    for chromosome in chromList:
+        #TODO: Really? Use first chr shift value for all of them
+        maxSampleCoord = getMaxCoordinate(sampleBAM, chromosome, doMulti=useMulti)
         if regionFinder.shiftValue == "learn":
-            learnShift(regionFinder, hitRDS, chromosome, logfilename, outfilename, outfile, useMulti, doControl, controlRDS, combine5p)
+            regionFinder.shiftValue = learnShift(regionFinder, sampleBAM, maxSampleCoord, chromosome, logfilename, outfilename, outfile, useMulti, controlBAM, combine5p)
 
-        allregions, outregions = findPeakRegions(regionFinder, hitRDS, chromosome, logfilename, outfilename, outfile, useMulti, doControl, controlRDS, combine5p)
-        if doRevBackground:
-            backregions = findBackgroundRegions(regionFinder, hitRDS, controlRDS, chromosome, useMulti)
-            writeChromosomeResults(regionFinder, outregions, outfile, doPvalue, shiftDict, allregions, header, backregions=backregions, pValueType=pValueType)
+        allregions, outregions = findPeakRegions(regionFinder, sampleBAM, maxSampleCoord, chromosome, logfilename, outfilename, outfile, useMulti, controlBAM, combine5p)
+        if regionFinder.doRevBackground:
+            maxControlCoord = getMaxCoordinate(controlBAM, chromosome, doMulti=useMulti)
+            backregions = findBackgroundRegions(regionFinder, sampleBAM, controlBAM, maxControlCoord, chromosome, useMulti)
         else:
-            writeNoRevBackgroundResults(regionFinder, outregions, outfile, doPvalue, shiftDict, allregions, header)
+            backregions = []
+            pValueType = "self"
+
+        writeChromosomeResults(regionFinder, outregions, outfile, shiftDict, allregions, header, backregions=backregions, pValueType=pValueType)
 
-    footer = getFooter(regionFinder, shiftDict, doRevBackground)
+    try:
+        bestShift = getBestShiftInDict(shiftDict)
+    except ValueError:
+        bestShift = 0
+
+    footer = regionFinder.getFooter(bestShift)
     print footer
     print >> outfile, footer
     outfile.close()
     writeLog(logfilename, versionString, outfilename + footer.replace("\n#"," | ")[:-1])
 
 
+def getHeaderComment(bamHeader, commentKey):
+    for comment in bamHeader["CO"]:
+        fields = comment.split("\t")
+        if fields[0] == commentKey:
+            return fields[1]
+
+    raise KeyError
+
+
 def getPValueType(ptype, doControl, doRevBackground):
     pValueType = "self"
     if ptype in ["NONE", "SELF", "BACK"]:
@@ -424,52 +469,41 @@ def getPValueType(ptype, doControl, doRevBackground):
     return pValueType
 
 
-def openRDSFile(filename, cachePages=None, doCache=False):
-    rds = ReadDataset.ReadDataset(filename, verbose=True, cache=doCache)
-    if cachePages > rds.getDefaultCacheSize():
-        rds.setDBcache(cachePages)
-
-    return rds
-
-
-def writeOutputFileHeader(regionFinder, outfile, hitfile, useMulti, doCache, pValueType, doPvalue, controlfile, doControl):
-    print >> outfile, regionFinder.getAnalysisDescription(hitfile, useMulti, doCache, pValueType, controlfile, doControl)
-    header = regionFinder.getHeader(doPvalue)
+def writeOutputFileHeader(regionFinder, outfile, hitfile, useMulti, pValueType):
+    print >> outfile, regionFinder.getAnalysisDescription(hitfile, useMulti, pValueType)
+    header = regionFinder.getHeader()
     print >> outfile, header
 
     return header
 
 
-def getChromosomeListToProcess(hitRDS, controlRDS=None, doControl=False):
-    hitChromList = hitRDS.getChromosomes()
-    if doControl:
-        controlChromList = controlRDS.getChromosomes()
-        chromosomeList = [chrom for chrom in hitChromList if chrom in controlChromList and chrom != "chrM"]
+def getChromosomeListToProcess(sampleBAM, controlBAM=None):
+    if controlBAM is not None:
+        chromosomeList = [chrom for chrom in sampleBAM.references if chrom in controlBAM.references and chrom != "chrM"]
     else:
-        chromosomeList = [chrom for chrom in hitChromList if chrom != "chrM"]
+        chromosomeList = [chrom for chrom in sampleBAM.references if chrom != "chrM"]
 
     return chromosomeList
 
 
-def findPeakRegions(regionFinder, hitRDS, chromosome, logfilename, outfilename,
-                    outfile, useMulti, doControl, controlRDS, combine5p):
+def findPeakRegions(regionFinder, sampleBAM, maxCoord, chromosome, logfilename, outfilename,
+                    outfile, useMulti, controlBAM, combine5p):
 
     outregions = []
     allregions = []
     print "chromosome %s" % (chromosome)
     previousHit = - 1 * regionFinder.maxSpacing
     readStartPositions = [-1]
-    totalWeight = 0
+    totalWeight = 0.0
     uniqueReadCount = 0
     reads = []
-    numStarts = 0
-    badRegion = False
-    hitDict = hitRDS.getReadsDict(fullChrom=True, chrom=chromosome, flag=regionFinder.withFlag, withWeight=True, doMulti=useMulti, findallOptimize=True,
-                                  strand=regionFinder.stranded, combine5p=combine5p)
+    numStartsInRegion = 0
 
-    maxCoord = hitRDS.getMaxCoordinate(chromosome, doMulti=useMulti)
-    for read in hitDict[chromosome]:
-        pos = read["start"]
+    for alignedread in sampleBAM.fetch(chromosome):
+        if doNotProcessRead(alignedread, doMulti=useMulti, strand=regionFinder.stranded, combine5p=combine5p):
+            continue
+
+        pos = alignedread.pos
         if previousRegionIsDone(pos, previousHit, regionFinder.maxSpacing, maxCoord):
             lastReadPos = readStartPositions[-1]
             lastBasePosition = lastReadPos + regionFinder.readlen - 1
@@ -485,8 +519,8 @@ def findPeakRegions(regionFinder, hitRDS, chromosome, logfilename, outfilename,
 
             allregions.append(int(region.numReads))
             regionLength = lastReadPos - region.start
-            if regionPassesCriteria(regionFinder, region.numReads, numStarts, regionLength):
-                region.foldRatio = getFoldRatio(regionFinder, controlRDS, region.numReads, chromosome, region.start, lastReadPos, useMulti, doControl)
+            if regionPassesCriteria(regionFinder, region.numReads, numStartsInRegion, regionLength):
+                region.foldRatio = getFoldRatio(regionFinder, controlBAM, region.numReads, chromosome, region.start, lastReadPos, useMulti)
 
                 if region.foldRatio >= regionFinder.minRatio:
                     # first pass, with absolute numbers
@@ -495,25 +529,24 @@ def findPeakRegions(regionFinder, hitRDS, chromosome, logfilename, outfilename,
                         try:
                             lastReadPos = trimRegion(region, regionFinder, peak, lastReadPos, regionFinder.trimValue, reads, regionFinder.sampleRDSsize)
                         except IndexError:
-                            badRegion = True
+                            regionFinder.statistics["badRegionTrim"] += 1
                             continue
 
-                        region.foldRatio = getFoldRatio(regionFinder, controlRDS, region.numReads, chromosome, region.start, lastReadPos, useMulti, doControl)
+                        region.foldRatio = getFoldRatio(regionFinder, controlBAM, region.numReads, chromosome, region.start, lastReadPos, useMulti)
 
-                    # just in case it changed, use latest data
                     try:
                         bestPos = peak.topPos[0]
                         peakScore = peak.smoothArray[bestPos]
                         if regionFinder.normalize:
                             peakScore /= regionFinder.sampleRDSsize
-                    except:
+                    except (IndexError, AttributeError, ZeroDivisionError):
                         continue
 
                     if regionFinder.listPeak:
-                        region.peakDescription= "%d\t%.1f" % (region.start + bestPos, peakScore)
+                        region.peakDescription = "%d\t%.1f" % (region.start + bestPos, peakScore)
 
                     if useMulti:
-                        setMultireadPercentage(region, hitRDS, regionFinder.sampleRDSsize, totalWeight, uniqueReadCount, chromosome, lastReadPos,
+                        setMultireadPercentage(region, sampleBAM, regionFinder.sampleRDSsize, totalWeight, uniqueReadCount, chromosome, lastReadPos,
                                                regionFinder.normalize, regionFinder.doTrim)
 
                     region.shift = peak.shift
@@ -530,43 +563,75 @@ def findPeakRegions(regionFinder, hitRDS, chromosome, logfilename, outfilename,
                             regionFinder.statistics["failed"] += 1
 
             readStartPositions = []
-            totalWeight = 0
+            totalWeight = 0.0
             uniqueReadCount = 0
             reads = []
-            numStarts = 0
-            if badRegion:
-                badRegion = False
-                regionFinder.statistics["badRegionTrim"] += 1
+            numStartsInRegion = 0
 
         if pos not in readStartPositions:
-            numStarts += 1
+            numStartsInRegion += 1
 
         readStartPositions.append(pos)
-        weight = read["weight"]
+        weight = 1.0/alignedread.opt('NH')
         totalWeight += weight
         if weight == 1.0:
             uniqueReadCount += 1
 
-        reads.append({"start": pos, "sense": read["sense"], "weight": weight})
+        reads.append({"start": pos, "sense": getReadSense(alignedread), "weight": weight})
         previousHit = pos
 
     return allregions, outregions
 
 
-def findBackgroundRegions(regionFinder, hitRDS, controlRDS, chromosome, useMulti):
+def getReadSense(read):
+    if read.is_reverse:
+        sense = "-"
+    else:
+        sense = "+"
+
+    return sense
+
+
+def doNotProcessRead(read, doMulti=False, strand="both", combine5p=False):
+    if read.opt('NH') > 1 and not doMulti:
+        return True
+
+    if strand == "+" and read.is_reverse:
+        return True
+
+    if strand == "-" and not read.is_reverse:
+        return True
+        
+    return False
+
+
+def getMaxCoordinate(samfile, chr, doMulti=False):
+    maxCoord = 0
+    for alignedread in samfile.fetch(chr):
+        if alignedread.opt('NH') > 1:
+            if doMulti:
+                maxCoord = max(maxCoord, alignedread.pos)
+        else:
+            maxCoord = max(maxCoord, alignedread.pos)
+
+    return maxCoord
+
+
+def findBackgroundRegions(regionFinder, sampleBAM, controlBAM, maxCoord, chromosome, useMulti):
     #TODO: this is *almost* the same calculation - there are small yet important differences
     print "calculating background..."
     previousHit = - 1 * regionFinder.maxSpacing
     currentHitList = [-1]
-    currentTotalWeight = 0
+    currentTotalWeight = 0.0
     currentReadList = []
     backregions = []
     numStarts = 0
     badRegion = False
-    hitDict = controlRDS.getReadsDict(fullChrom=True, chrom=chromosome, withWeight=True, doMulti=useMulti, findallOptimize=True)
-    maxCoord = controlRDS.getMaxCoordinate(chromosome, doMulti=useMulti)
-    for read in hitDict[chromosome]:
-        pos = read["start"]
+    for alignedread in controlBAM.fetch(chromosome):
+        if doNotProcessRead(alignedread, doMulti=useMulti):
+            continue
+
+        pos = alignedread.pos
         if previousRegionIsDone(pos, previousHit, regionFinder.maxSpacing, maxCoord):
             lastReadPos = currentHitList[-1]
             lastBasePosition = lastReadPos + regionFinder.readlen - 1
@@ -578,7 +643,7 @@ def findBackgroundRegions(regionFinder, hitRDS, controlRDS, chromosome, useMulti
             region = Region.Region(currentHitList[0], lastBasePosition, chrom=chromosome, label=regionFinder.regionLabel, numReads=currentTotalWeight)
             regionLength = lastReadPos - region.start
             if regionPassesCriteria(regionFinder, region.numReads, numStarts, regionLength):
-                numMock = 1. + hitRDS.getCounts(chromosome, region.start, lastReadPos, uniqs=True, multi=useMulti, splices=False, reportCombined=True)
+                numMock = 1. + countReadsInRegion(sampleBAM, chromosome, region.start, lastReadPos, countMulti=useMulti)
                 if regionFinder.normalize:
                     numMock /= regionFinder.sampleRDSsize
 
@@ -595,13 +660,12 @@ def findBackgroundRegions(regionFinder, hitRDS, controlRDS, chromosome, useMulti
                             badRegion = True
                             continue
 
-                        numMock = 1. + hitRDS.getCounts(chromosome, region.start, lastReadPos, uniqs=True, multi=useMulti, splices=False, reportCombined=True)
+                        numMock = 1. + countReadsInRegion(sampleBAM, chromosome, region.start, lastReadPos, countMulti=useMulti)
                         if regionFinder.normalize:
                             numMock /= regionFinder.sampleRDSsize
 
                         foldRatio = region.numReads / numMock
 
-                    # just in case it changed, use latest data
                     try:
                         bestPos = peak.topPos[0]
                         peakScore = peak.smoothArray[bestPos]
@@ -618,7 +682,7 @@ def findBackgroundRegions(regionFinder, hitRDS, controlRDS, chromosome, useMulti
                         regionFinder.updateControlStatistics(peak, region.numReads, peakScore)
 
             currentHitList = []
-            currentTotalWeight = 0
+            currentTotalWeight = 0.0
             currentReadList = []
             numStarts = 0
             if badRegion:
@@ -629,32 +693,31 @@ def findBackgroundRegions(regionFinder, hitRDS, controlRDS, chromosome, useMulti
             numStarts += 1
 
         currentHitList.append(pos)
-        weight = read["weight"]
+        weight = 1.0/alignedread.opt('NH')
         currentTotalWeight += weight
-        currentReadList.append({"start": pos, "sense": read["sense"], "weight": weight})
+        currentReadList.append({"start": pos, "sense": getReadSense(alignedread), "weight": weight})
         previousHit = pos
 
     return backregions
 
 
-def learnShift(regionFinder, hitRDS, chromosome, logfilename, outfilename,
-               outfile, useMulti, doControl, controlRDS, combine5p):
-
-    hitDict = hitRDS.getReadsDict(fullChrom=True, chrom=chromosome, flag=regionFinder.withFlag, withWeight=True, doMulti=useMulti, findallOptimize=True,
-                                  strand=regionFinder.stranded, combine5p=combine5p)
+def learnShift(regionFinder, sampleBAM, maxCoord, chromosome, logfilename, outfilename,
+               outfile, useMulti, controlBAM, combine5p):
 
-    maxCoord = hitRDS.getMaxCoordinate(chromosome, doMulti=useMulti)
     print "learning shift.... will need at least 30 training sites"
     stringency = regionFinder.stringency
     previousHit = -1 * regionFinder.maxSpacing
     positionList = [-1]
-    totalWeight = 0
+    totalWeight = 0.0
     readList = []
     shiftDict = {}
     count = 0
     numStarts = 0
-    for read in hitDict[chromosome]:
-        pos = read["start"]
+    for alignedread in sampleBAM.fetch(chromosome):
+        if doNotProcessRead(alignedread, doMulti=useMulti, strand=regionFinder.stranded, combine5p=combine5p):
+            continue
+
+        pos = alignedread.pos
         if previousRegionIsDone(pos, previousHit, regionFinder.maxSpacing, maxCoord):
             if regionFinder.normalize:
                 totalWeight /= regionFinder.sampleRDSsize
@@ -663,22 +726,22 @@ def learnShift(regionFinder, hitRDS, chromosome, logfilename, outfilename,
             regionStop = positionList[-1]
             regionLength = regionStop - regionStart
             if regionPassesCriteria(regionFinder, totalWeight, numStarts, regionLength, stringency=stringency):
-                foldRatio = getFoldRatio(regionFinder, controlRDS, totalWeight, chromosome, regionStart, regionStop, useMulti, doControl)
+                foldRatio = getFoldRatio(regionFinder, controlBAM, totalWeight, chromosome, regionStart, regionStop, useMulti)
                 if foldRatio >= regionFinder.minRatio:
                     updateShiftDict(shiftDict, readList, regionStart, regionLength, regionFinder.readlen)
                     count += 1
 
             positionList = []
-            totalWeight = 0
+            totalWeight = 0.0
             readList = []
 
         if pos not in positionList:
             numStarts += 1
 
         positionList.append(pos)
-        weight = read["weight"]
+        weight = 1.0/alignedread.opt('NH')
         totalWeight += weight
-        readList.append({"start": pos, "sense": read["sense"], "weight": weight})
+        readList.append({"start": pos, "sense": getReadSense(alignedread), "weight": weight})
         previousHit = pos
 
     outline = "#learn: stringency=%.2f min_signal=%2.f min_ratio=%.2f min_region_size=%d\n#number of training examples: %d" % (stringency,
@@ -689,12 +752,14 @@ def learnShift(regionFinder, hitRDS, chromosome, logfilename, outfilename,
 
     print outline
     writeLog(logfilename, versionString, outfilename + outline)
-    regionFinder.shiftValue = getShiftValue(shiftDict, count, logfilename, outfilename)
-    outline = "#picked shiftValue to be %d" % regionFinder.shiftValue
+    shiftValue = getShiftValue(shiftDict, count, logfilename, outfilename)
+    outline = "#picked shiftValue to be %d" % shiftValue
     print outline
     print >> outfile, outline
     writeLog(logfilename, versionString, outfilename + outline)
 
+    return shiftValue, shiftDict
+
 
 def previousRegionIsDone(pos, previousHit, maxSpacing, maxCoord):
     return abs(pos - previousHit) > maxSpacing or pos == maxCoord
@@ -705,7 +770,7 @@ def regionPassesCriteria(regionFinder, sumAll, numStarts, regionLength, stringen
     minNumReadStarts = stringency * regionFinder.minRatio
     minRegionLength = stringency * regionFinder.readlen
 
-    return sumAll >= minTotalReads and numStarts > minNumReadStarts and regionLength > minRegionLength
+    return sumAll >= minTotalReads and numStarts >= minNumReadStarts and regionLength > minRegionLength
 
 
 def trimRegion(region, regionFinder, peak, regionStop, trimValue, currentReadList, totalReadCount):
@@ -758,12 +823,12 @@ def peakEdgeLocated(peak, position, minSignalThresh):
     return peak.smoothArray[position] >= minSignalThresh or position == peak.topPos[0]
 
 
-def getFoldRatio(regionFinder, controlRDS, sumAll, chromosome, regionStart, regionStop, useMulti, doControl):
+def getFoldRatio(regionFinder, controlBAM, sumAll, chromosome, regionStart, regionStop, useMulti):
     """ Fold ratio calculated is total read weight over control
     """
     #TODO: this needs to be generalized as there is a point at which we want to use the sampleRDS instead of controlRDS
-    if doControl:
-        numMock = 1. + controlRDS.getCounts(chromosome, regionStart, regionStop, uniqs=True, multi=useMulti, splices=False, reportCombined=True)
+    if regionFinder.doControl:
+        numMock = 1. + countReadsInRegion(controlBAM, chromosome, regionStart, regionStop, countMulti=useMulti)
         if regionFinder.normalize:
             numMock /= regionFinder.controlRDSsize
 
@@ -774,6 +839,25 @@ def getFoldRatio(regionFinder, controlRDS, sumAll, chromosome, regionStart, regi
     return foldRatio
 
 
+def countReadsInRegion(bamfile, chr, start, end, uniqs=True, countMulti=False, countSplices=False):
+    count = 0.0
+    for alignedread in bamfile.fetch(chr, start, end):
+        if alignedread.opt('NH') > 1:
+            if countMulti:
+                if isSpliceEntry(alignedread.cigar):
+                    if countSplices:
+                        count += 1.0/alignedread.opt('NH')
+                else:
+                    count += 1.0/alignedread.opt('NH')
+        elif uniqs:
+            if isSpliceEntry(alignedread.cigar):
+                if countSplices:
+                    count += 1.0
+            else:
+                count += 1.0
+
+    return count
+
 def updateShiftDict(shiftDict, readList, regionStart, regionLength, readlen):
     peak = findPeak(readList, regionStart, regionLength, readlen, doWeight=True, shift="auto")
     try:
@@ -819,9 +903,12 @@ def getRegion(regionStart, regionStop, factor, index, chromosome, sumAll, foldRa
     return region
 
 
-def setMultireadPercentage(region, hitRDS, hitRDSsize, currentTotalWeight, currentUniqueCount, chromosome, lastReadPos, normalize, doTrim):
+def setMultireadPercentage(region, sampleBAM, hitRDSsize, currentTotalWeight, currentUniqueCount, chromosome, lastReadPos, normalize, doTrim):
     if doTrim:
-        sumMulti = hitRDS.getMultiCount(chromosome, region.start, lastReadPos)
+        sumMulti = 0.0
+        for alignedread in sampleBAM.fetch(chromosome, region.start, lastReadPos):
+            if alignedread.opt('NH') > 1:
+                sumMulti += 1.0/alignedread.opt('NH')
     else:
         sumMulti = currentTotalWeight - currentUniqueCount
 
@@ -856,25 +943,33 @@ def updateRegion(region, doDirectionality, leftPlusRatio, numLeft, numPlus, plus
             raise RegionDirectionError
 
 
-def writeNoRevBackgroundResults(regionFinder, outregions, outfile, doPvalue, shiftDict,
-                                allregions, header):
-
-    writeChromosomeResults(regionFinder, outregions, outfile, doPvalue, shiftDict,
-                           allregions, header, backregions=[], pValueType="self")
-
-
-def writeChromosomeResults(regionFinder, outregions, outfile, doPvalue, shiftDict,
+def writeChromosomeResults(regionFinder, outregions, outfile, shiftDict,
                            allregions, header, backregions=[], pValueType="none"):
 
     print regionFinder.statistics["mIndex"], regionFinder.statistics["mTotal"]
-    if doPvalue:
+    if regionFinder.doPvalue:
         if pValueType == "self":
             poissonmean = calculatePoissonMean(allregions)
         else:
             poissonmean = calculatePoissonMean(backregions)
 
     print header
-    writeRegions(outregions, outfile, doPvalue, poissonmean, shiftValue=regionFinder.shiftValue, reportshift=regionFinder.reportshift, shiftDict=shiftDict)
+    for region in outregions:
+        if regionFinder.shiftValue == "auto" and regionFinder.reportshift:
+            try:
+                shiftDict[region.shift] += 1
+            except KeyError:
+                shiftDict[region.shift] = 1
+
+        outline = getRegionString(region, regionFinder.reportshift)
+
+        # iterative poisson from http://stackoverflow.com/questions/280797?sort=newest
+        if regionFinder.doPvalue:
+            pValue = calculatePValue(int(region.numReads), poissonmean)
+            outline += "\t%1.2g" % pValue
+
+        print outline
+        print >> outfile, outline
 
 
 def calculatePoissonMean(dataList):
@@ -890,29 +985,8 @@ def calculatePoissonMean(dataList):
     return poissonmean
 
 
-def writeRegions(outregions, outfile, doPvalue, poissonmean, shiftValue=0, reportshift=False, shiftDict={}):
-    for region in outregions:
-        if shiftValue == "auto" and reportshift:
-            try:
-                shiftDict[region.shift] += 1
-            except KeyError:
-                shiftDict[region.shift] = 1
-
-        outline = getRegionString(region, reportshift)
-
-        # iterative poisson from http://stackoverflow.com/questions/280797?sort=newest
-        if doPvalue:
-            sumAll = int(region.numReads)
-            pValue = calculatePValue(sumAll, poissonmean)
-            outline += "\t%1.2g" % pValue
-
-        print outline
-        print >> outfile, outline
-
-
 def calculatePValue(sum, poissonmean):
     pValue = math.exp(-poissonmean)
-    #TODO: 798: DeprecationWarning: integer argument expected, got float - for i in xrange(sum)
     for i in xrange(sum):
         pValue *= poissonmean
         pValue /= i+1
@@ -929,34 +1003,9 @@ def getRegionString(region, reportShift):
     return outline
 
 
-def getFooter(regionFinder, shiftDict, doRevBackground):
-    index = regionFinder.statistics["index"]
-    mIndex = regionFinder.statistics["mIndex"]
-    footerLines = ["#stats:\t%.1f RPM in %d regions" % (regionFinder.statistics["total"], index)]
-    if regionFinder.doDirectionality:
-        footerLines.append("#\t\t%d additional regions failed directionality filter" % regionFinder.statistics["failed"])
-
-    if doRevBackground:
-        try:
-            percent = min(100. * (float(mIndex)/index), 100.)
-        except ZeroDivisionError:
-            percent = 0.
-
-        footerLines.append("#%d regions (%.1f RPM) found in background (FDR = %.2f percent)" % (mIndex, regionFinder.statistics["mTotal"], percent))
-
-    if regionFinder.shiftValue == "auto" and regionFinder.reportshift:
-        bestShift = getBestShiftInDict(shiftDict)
-        footerLines.append("#mode of shift values: %d" % bestShift)
-
-    if regionFinder.statistics["badRegionTrim"] > 0:
-        footerLines.append("#%d regions discarded due to trimming problems" % regionFinder.statistics["badRegionTrim"])
-
-    return string.join(footerLines, "\n")
-
-
 def getBestShiftInDict(shiftDict):
     return max(shiftDict.iteritems(), key=operator.itemgetter(1))[0]
 
 
 if __name__ == "__main__":
-    main(sys.argv)
\ No newline at end of file
+    main(sys.argv)