rewrite of findall.py and MakeRdsFromBam to fix bugs resulting from poor initial...
[erange.git] / findall.py
index 663d7e075d5d30535da6eed6547d4da2b4970be2..f2c70432b1ccc477ca3ccee55b708a0e3e82f4bf 100755 (executable)
@@ -60,11 +60,12 @@ print versionString
 
 class RegionDirectionError(Exception):
     pass
-
+            
 
 class RegionFinder():
-    def __init__(self, minRatio, minPeak, minPlusRatio, maxPlusRatio, leftPlusRatio, strandfilter, minHits, trimValue, doTrim, doDirectionality,
-                 shiftValue, maxSpacing):
+    def __init__(self, label, minRatio=4.0, minPeak=0.5, minPlusRatio=0.25, maxPlusRatio=0.75, leftPlusRatio=0.3, strandfilter="",
+                 minHits=4.0, trimValue=0.1, doTrim=True, doDirectionality=True, shiftValue=0, maxSpacing=50, withFlag="",
+                 normalize=True, listPeak=False, reportshift=False, stringency=1.0):
 
         self.statistics = {"index": 0,
                            "total": 0,
@@ -74,6 +75,8 @@ class RegionFinder():
                            "badRegionTrim": 0
         }
 
+        self.regionLabel = label
+        self.rnaSettings = False
         self.controlRDSsize = 1
         self.sampleRDSsize = 1
         self.minRatio = minRatio
@@ -107,17 +110,89 @@ class RegionFinder():
 
         self.shiftValue = shiftValue
         self.maxSpacing = maxSpacing
+        self.withFlag = withFlag
+        self.normalize = normalize
+        self.listPeak = listPeak
+        self.reportshift = reportshift
+        self.stringency = max(stringency, 1.0)
 
 
-    def useRNASettings(self):
+    def useRNASettings(self, readlen):
+        self.rnaSettings = True
         self.shiftValue = 0
         self.doTrim = False
         self.doDirectionality = False
+        self.maxSpacing = readlen
+
+
+    def getHeader(self, doPvalue):
+        """ Build the tab-delimited column header line for the region output.
+
+            The count column is "RPM" when self.normalize is set and "COUNT"
+            otherwise; the plus%/leftPlus%, peak, readShift and pValue columns
+            are appended only when doDirectionality, listPeak, reportshift and
+            doPvalue respectively are true.
+        """
+
+        countType = "RPM" if self.normalize else "COUNT"
+        columns = ["#regionID\tchrom\tstart\tstop", countType, "fold\tmulti%"]
+
+        # (enabled, column text) pairs in output order
+        optionalColumns = [(self.doDirectionality, "plus%\tleftPlus%"),
+                           (self.listPeak, "peakPos\tpeakHeight"),
+                           (self.reportshift, "readShift"),
+                           (doPvalue, "pValue")]
+
+        columns.extend([text for (enabled, text) in optionalColumns if enabled])
+
+        return string.join(columns, "\t")
+
+
+    def printSettings(self, doRevBackground, ptype, doControl, useMulti, doCache, pValueType):
+        print  # bare print statement: writes an empty line before the settings report
+        self.printStatusMessages(doRevBackground, ptype, doControl, useMulti)  # status lines for the options in effect
+        self.printOptionsSummary(useMulti, doCache, pValueType)  # summary of thresholds and flags
+
+
+    def printStatusMessages(self, doRevBackground, ptype, doControl, useMulti):
+        """ Print status lines describing the options in effect for this run.
+
+            A pValue type of "BACK" needs both doControl and doRevBackground; any
+            type other than "", "NONE", "SELF" or "BACK" is reported as unusable.
+        """
+        if self.shiftValue == "learn":
+            print "Will try to learn shift"
+
+        if self.normalize:
+            print "Normalizing to RPM"
+
+        if doRevBackground:
+            print "Swapping IP and background to calculate FDR"
+
+        if ptype == "BACK":
+            if not (doControl and doRevBackground):
+                print "must have a control dataset and -revbackground for pValue type 'back'"
+        elif ptype not in ["", "NONE", "SELF"]:
+            print "could not use pValue type : %s" % ptype
+
+        if self.withFlag != "":
+            print "restrict to flag = %s" % self.withFlag
+
+        if not useMulti:
+            print "using unique reads only"
+
+        if self.rnaSettings:
+            print "using settings appropriate for RNA: -nodirectionality -notrim -noshift"
+
+        if self.strandfilter == "plus":
+            print "only analyzing reads on the plus strand"
+        elif self.strandfilter == "minus":
+            print "only analyzing reads on the minus strand"
 
 
-    def printOptionsSummary(self, listPeak, useMulti, doCache, pValueType):
+    def printOptionsSummary(self, useMulti, doCache, pValueType):
 
-        print "\nenforceDirectionality=%s listPeak=%s nomulti=%s cache=%s " % (self.doDirectionality, listPeak, not useMulti, doCache)
+        print "\nenforceDirectionality=%s listPeak=%s nomulti=%s cache=%s " % (self.doDirectionality, self.listPeak, not useMulti, doCache)
         print "spacing<%d minimum>%.1f ratio>%.1f minPeak=%.1f\ttrimmed=%s\tstrand=%s" % (self.maxSpacing, self.minHits, self.minRatio, self.minPeak, self.trimString, self.stranded)
         try:
             print "minPlus=%.2f maxPlus=%.2f leftPlus=%.2f shift=%d pvalue=%s" % (self.minPlusRatio, self.maxPlusRatio, self.leftPlusRatio, self.shiftValue, pValueType)
@@ -125,8 +200,7 @@ class RegionFinder():
             print "minPlus=%.2f maxPlus=%.2f leftPlus=%.2f shift=%s pvalue=%s" % (self.minPlusRatio, self.maxPlusRatio, self.leftPlusRatio, self.shiftValue, pValueType)
 
 
-    def getAnalysisDescription(self, hitfile, controlfile, doControl,
-                                 withFlag, listPeak, useMulti, doCache, pValueType):
+    def getAnalysisDescription(self, hitfile, useMulti, doCache, pValueType, controlfile, doControl):
 
         description = ["#ERANGE %s" % versionString]
         if doControl:
@@ -134,10 +208,10 @@ class RegionFinder():
         else:
             description.append("#enriched sample:\t%s (%.1f M reads)\n#control sample: none" % (hitfile, self.sampleRDSsize))
 
-        if withFlag != "":
-            description.append("#restrict to Flag = %s" % withFlag)
+        if self.withFlag != "":
+            description.append("#restrict to Flag = %s" % self.withFlag)
 
-        description.append("#enforceDirectionality=%s listPeak=%s nomulti=%s cache=%s" % (self.doDirectionality, listPeak, not useMulti, doCache))
+        description.append("#enforceDirectionality=%s listPeak=%s nomulti=%s cache=%s" % (self.doDirectionality, self.listPeak, not useMulti, doCache))
         description.append("#spacing<%d minimum>%.1f ratio>%.1f minPeak=%.1f trimmed=%s strand=%s" % (self.maxSpacing, self.minHits, self.minRatio, self.minPeak, self.trimString, self.stranded))
         try:
             description.append("#minPlus=%.2f maxPlus=%.2f leftPlus=%.2f shift=%d pvalue=%s" % (self.minPlusRatio, self.maxPlusRatio, self.leftPlusRatio, self.shiftValue, pValueType))
@@ -147,6 +221,22 @@ class RegionFinder():
         return string.join(description, "\n")
 
 
+    def updateControlStatistics(self, peak, sumAll, peakScore):
+        """ Update self.statistics for one control region: mIndex/mTotal if it passes, failed if only directionality fails. """
+        plusRatio = float(peak.numPlus)/peak.numHits
+        if peakScore < self.minPeak or not (self.minPlusRatio <= plusRatio <= self.maxPlusRatio):
+            return
+
+        # float() matches plusRatio above; a bare "/" truncates when both counts are ints under Python 2.
+        # NOTE(review): the sample path reads peak.numLeftPlus - confirm numLeft is the intended attribute.
+        if self.doDirectionality and self.leftPlusRatio >= float(peak.numLeft) / peak.numPlus:
+            self.statistics["failed"] += 1
+            return
+        # passed; directionality was only checked if enabled
+        self.statistics["mIndex"] += 1
+        self.statistics["mTotal"] += sumAll
+
+
 def usage():
     print __doc__
 
@@ -186,13 +276,16 @@ def main(argv=None):
     else:
         outputMode = "w"
 
-    findall(factor, hitfile, outfilename, options.minHits, options.minRatio, options.maxSpacing, options.listPeak, shiftValue,
-            options.stringency, options.reportshift,
-            options.minPlusRatio, options.maxPlusRatio, options.leftPlusRatio, options.minPeak,
-            options.normalize, options.logfilename, options.withFlag, options.doDirectionality,
-            options.trimValue, options.doTrim, outputMode, options.rnaSettings,
-            options.cachePages, options.ptype, options.controlfile, options.doRevBackground, options.useMulti,
-            options.strandfilter, options.combine5p)
+    regionFinder = RegionFinder(factor, minRatio=options.minRatio, minPeak=options.minPeak, minPlusRatio=options.minPlusRatio,
+                                maxPlusRatio=options.maxPlusRatio, leftPlusRatio=options.leftPlusRatio, strandfilter=options.strandfilter,
+                                minHits=options.minHits, trimValue=options.trimValue, doTrim=options.doTrim,
+                                doDirectionality=options.doDirectionality, shiftValue=shiftValue, maxSpacing=options.maxSpacing,
+                                withFlag=options.withFlag, normalize=options.normalize, listPeak=options.listPeak,
+                                reportshift=options.reportshift, stringency=options.stringency)
+
+    findall(regionFinder, hitfile, outfilename, options.logfilename, outputMode, options.rnaSettings,
+            options.cachePages, options.ptype, options.controlfile, options.doRevBackground,
+            options.useMulti, options.combine5p)
 
 
 def makeParser():
@@ -270,72 +363,44 @@ def makeParser():
     return parser
 
 
-def findall(factor, hitfile, outfilename, minHits=4.0, minRatio=4.0, maxSpacing=50, listPeak=False, shiftValue=0,
-            stringency=4.0, reportshift=False, minPlusRatio=0.25, maxPlusRatio=0.75, leftPlusRatio=0.3, minPeak=0.5,
-            normalize=True, logfilename="findall.log", withFlag="", doDirectionality=True, trimValue=0.1, doTrim=True,
-            outputMode="w", rnaSettings=False, cachePages=None, ptype="", controlfile=None, doRevBackground=False,
-            useMulti=True, strandfilter="", combine5p=False):
-
-    regionFinder = RegionFinder(minRatio, minPeak, minPlusRatio, maxPlusRatio, leftPlusRatio, strandfilter, minHits, trimValue,
-                                doTrim, doDirectionality, shiftValue, maxSpacing)
-
-    doControl = controlfile is not None
-    pValueType = getPValueType(ptype, doControl, doRevBackground)
-    doPvalue = not pValueType == "none"
-
-    if rnaSettings:
-        regionFinder.useRNASettings()
+def findall(regionFinder, hitfile, outfilename, logfilename="findall.log", outputMode="w", rnaSettings=False, cachePages=None,
+            ptype="", controlfile=None, doRevBackground=False, useMulti=True, combine5p=False):
 
+    writeLog(logfilename, versionString, string.join(sys.argv[1:]))
     doCache = cachePages is not None
-    printStatusMessages(regionFinder.shiftValue, normalize, doRevBackground, ptype, doControl, withFlag, useMulti, rnaSettings, regionFinder.strandfilter)
     controlRDS = None
+    doControl = controlfile is not None
     if doControl:
         print "\ncontrol:" 
         controlRDS = openRDSFile(controlfile, cachePages=cachePages, doCache=doCache)
+        regionFinder.controlRDSsize = len(controlRDS) / 1000000.
 
     print "\nsample:" 
     hitRDS = openRDSFile(hitfile, cachePages=cachePages, doCache=doCache)
-
-    print
+    regionFinder.sampleRDSsize = len(hitRDS) / 1000000.
+    pValueType = getPValueType(ptype, doControl, doRevBackground)
+    doPvalue = not pValueType == "none"
     regionFinder.readlen = hitRDS.getReadSize()
     if rnaSettings:
-        regionFinder.maxSpacing = regionFinder.readlen
-
-    writeLog(logfilename, versionString, string.join(sys.argv[1:]))
-    regionFinder.printOptionsSummary(listPeak, useMulti, doCache, pValueType)
-
-    regionFinder.sampleRDSsize = len(hitRDS) / 1000000.
-    if doControl:
-        regionFinder.controlRDSsize = len(controlRDS) / 1000000.
+        regionFinder.useRNASettings(regionFinder.readlen)
 
+    regionFinder.printSettings(doRevBackground, ptype, doControl, useMulti, doCache, pValueType)
     outfile = open(outfilename, outputMode)
-    print >> outfile, regionFinder.getAnalysisDescription(hitfile, controlfile, doControl,
-                                                          withFlag, listPeak, useMulti, doCache, pValueType)
-
-    header = getHeader(normalize, regionFinder.doDirectionality, listPeak, reportshift, doPvalue)
-    print >> outfile, header
-    if minRatio < minPeak:
-        minPeak = minRatio
-
+    header = writeOutputFileHeader(regionFinder, outfile, hitfile, useMulti, doCache, pValueType, doPvalue, controlfile, doControl)
     shiftDict = {}
-    hitChromList = hitRDS.getChromosomes()
-    stringency = max(stringency, 1.0)
-    chromosomeList = getChromosomeListToProcess(hitChromList, controlRDS, doControl)
+    chromosomeList = getChromosomeListToProcess(hitRDS, controlRDS, doControl)
     for chromosome in chromosomeList:
-        allregions, outregions = findPeakRegions(regionFinder, hitRDS, controlRDS, chromosome, logfilename, outfilename,
-                                                 outfile, stringency, normalize, useMulti, doControl, withFlag, combine5p,
-                                                 factor, listPeak)
+        if regionFinder.shiftValue == "learn":
+            learnShift(regionFinder, hitRDS, chromosome, logfilename, outfilename, outfile, useMulti, doControl, controlRDS, combine5p)
 
+        allregions, outregions = findPeakRegions(regionFinder, hitRDS, chromosome, logfilename, outfilename, outfile, useMulti, doControl, controlRDS, combine5p)
         if doRevBackground:
-            #TODO: this is *almost* the same calculation - there are small yet important differences
-            backregions = findBackgroundRegions(regionFinder, hitRDS, controlRDS, chromosome, normalize, useMulti, factor)
-            writeChromosomeResults(regionFinder, outregions, outfile, doPvalue, reportshift, shiftDict,
-                                   allregions, header, backregions=backregions, pValueType=pValueType)
+            backregions = findBackgroundRegions(regionFinder, hitRDS, controlRDS, chromosome, useMulti)
+            writeChromosomeResults(regionFinder, outregions, outfile, doPvalue, shiftDict, allregions, header, backregions=backregions, pValueType=pValueType)
         else:
-            writeNoRevBackgroundResults(regionFinder, outregions, outfile, doPvalue, reportshift, shiftDict,
-                                        allregions, header)
+            writeNoRevBackgroundResults(regionFinder, outregions, outfile, doPvalue, shiftDict, allregions, header)
 
-    footer = getFooter(regionFinder, shiftDict, reportshift, doRevBackground)
+    footer = getFooter(regionFinder, shiftDict, doRevBackground)
     print footer
     print >> outfile, footer
     outfile.close()
@@ -358,42 +423,6 @@ def getPValueType(ptype, doControl, doRevBackground):
     return pValueType
 
 
-def printStatusMessages(shiftValue, normalize, doRevBackground, ptype, doControl, withFlag, useMulti, rnaSettings, strandfilter):
-    if shiftValue == "learn":
-        print "Will try to learn shift"
-
-    if normalize:
-        print "Normalizing to RPM"
-
-    if doRevBackground:
-        print "Swapping IP and background to calculate FDR"
-
-    if ptype != "":
-        if ptype in ["NONE", "SELF"]:
-            pass
-        elif ptype == "BACK":
-            if doControl and doRevBackground:
-                pass
-            else:
-                print "must have a control dataset and -revbackground for pValue type 'back'"
-        else:
-            print "could not use pValue type : %s" % ptype
-
-    if withFlag != "":
-        print "restrict to flag = %s" % withFlag
-
-    if not useMulti:
-        print "using unique reads only"
-
-    if rnaSettings:
-        print "using settings appropriate for RNA: -nodirectionality -notrim -noshift"
-
-    if strandfilter == "plus":
-        print "only analyzing reads on the plus strand"
-    elif strandfilter == "minus":
-        print "only analyzing reads on the minus strand"
-
-
 def openRDSFile(filename, cachePages=None, doCache=False):
     rds = ReadDataset.ReadDataset(filename, verbose=True, cache=doCache)
     if cachePages > rds.getDefaultCacheSize():
@@ -402,30 +431,16 @@ def openRDSFile(filename, cachePages=None, doCache=False):
     return rds
 
 
-def getHeader(normalize, doDirectionality, listPeak, reportshift, doPvalue):
-    if normalize:
-        countType = "RPM"
-    else:
-        countType = "COUNT"
-
-    headerFields = ["#regionID\tchrom\tstart\tstop", countType, "fold\tmulti%"]
-
-    if doDirectionality:
-        headerFields.append("plus%\tleftPlus%")
-
-    if listPeak:
-        headerFields.append("peakPos\tpeakHeight")
-
-    if reportshift:
-        headerFields.append("readShift")
-
-    if doPvalue:
-        headerFields.append("pValue")
+def writeOutputFileHeader(regionFinder, outfile, hitfile, useMulti, doCache, pValueType, doPvalue, controlfile, doControl):
+    print >> outfile, regionFinder.getAnalysisDescription(hitfile, useMulti, doCache, pValueType, controlfile, doControl)
+    header = regionFinder.getHeader(doPvalue)
+    print >> outfile, header
 
-    return string.join(headerFields, "\t")
+    return header
 
 
-def getChromosomeListToProcess(hitChromList, controlRDS, doControl):
+def getChromosomeListToProcess(hitRDS, controlRDS=None, doControl=False):
+    hitChromList = hitRDS.getChromosomes()
     if doControl:
         controlChromList = controlRDS.getChromosomes()
         chromosomeList = [chrom for chrom in hitChromList if chrom in controlChromList and chrom != "chrM"]
@@ -435,8 +450,8 @@ def getChromosomeListToProcess(hitChromList, controlRDS, doControl):
     return chromosomeList
 
 
-def findPeakRegions(regionFinder, hitRDS, controlRDS, chromosome, logfilename, outfilename,
-                    outfile, stringency, normalize, useMulti, doControl, withFlag, combine5p, factor, listPeak):
+def findPeakRegions(regionFinder, hitRDS, chromosome, logfilename, outfilename,
+                    outfile, useMulti, doControl, controlRDS, combine5p):
 
     outregions = []
     allregions = []
@@ -448,14 +463,10 @@ def findPeakRegions(regionFinder, hitRDS, controlRDS, chromosome, logfilename, o
     reads = []
     numStarts = 0
     badRegion = False
-    hitDict = hitRDS.getReadsDict(fullChrom=True, chrom=chromosome, flag=withFlag, withWeight=True, doMulti=useMulti, findallOptimize=True,
+    hitDict = hitRDS.getReadsDict(fullChrom=True, chrom=chromosome, flag=regionFinder.withFlag, withWeight=True, doMulti=useMulti, findallOptimize=True,
                                   strand=regionFinder.stranded, combine5p=combine5p)
 
     maxCoord = hitRDS.getMaxCoordinate(chromosome, doMulti=useMulti)
-    if regionFinder.shiftValue == "learn":
-        learnShift(regionFinder, hitDict, controlRDS, chromosome, maxCoord, logfilename, outfilename,
-                                outfile, numStarts, stringency, normalize, useMulti, doControl)
-
     for read in hitDict[chromosome]:
         pos = read["start"]
         if previousRegionIsDone(pos, previousHit, regionFinder.maxSpacing, maxCoord):
@@ -463,54 +474,52 @@ def findPeakRegions(regionFinder, hitRDS, controlRDS, chromosome, logfilename, o
             lastBasePosition = lastReadPos + regionFinder.readlen - 1
             newRegionIndex = regionFinder.statistics["index"] + 1
             if regionFinder.doDirectionality:
-                region = Region.DirectionalRegion(readStartPositions[0], lastBasePosition, chrom=chromosome, index=newRegionIndex, label=factor, numReads=totalWeight)
+                region = Region.DirectionalRegion(readStartPositions[0], lastBasePosition, chrom=chromosome, index=newRegionIndex, label=regionFinder.regionLabel,
+                                                  numReads=totalWeight)
             else:
-                region = Region.Region(readStartPositions[0], lastBasePosition, chrom=chromosome, index=newRegionIndex, label=factor, numReads=totalWeight)
+                region = Region.Region(readStartPositions[0], lastBasePosition, chrom=chromosome, index=newRegionIndex, label=regionFinder.regionLabel, numReads=totalWeight)
 
-            if normalize:
+            if regionFinder.normalize:
                 region.numReads /= regionFinder.sampleRDSsize
 
             allregions.append(int(region.numReads))
             regionLength = lastReadPos - region.start
-            if regionPassesCriteria(region.numReads, regionFinder.minHits, numStarts, regionFinder.minRatio, regionLength, regionFinder.readlen):
-                region.foldRatio = getFoldRatio(regionFinder, controlRDS, region.numReads, chromosome, region.start, lastReadPos,
-                                                useMulti, doControl, normalize)
+            if regionPassesCriteria(regionFinder, region.numReads, numStarts, regionLength):
+                region.foldRatio = getFoldRatio(regionFinder, controlRDS, region.numReads, chromosome, region.start, lastReadPos, useMulti, doControl)
 
                 if region.foldRatio >= regionFinder.minRatio:
                     # first pass, with absolute numbers
                     peak = findPeak(reads, region.start, regionLength, regionFinder.readlen, doWeight=True, leftPlus=regionFinder.doDirectionality, shift=regionFinder.shiftValue)
                     if regionFinder.doTrim:
                         try:
-                            lastReadPos = trimRegion(region, peak, lastReadPos, regionFinder.trimValue, reads, regionFinder.readlen, regionFinder.doDirectionality,
-                                                     normalize, regionFinder.sampleRDSsize)
+                            lastReadPos = trimRegion(region, regionFinder, peak, lastReadPos, regionFinder.trimValue, reads, regionFinder.sampleRDSsize)
                         except IndexError:
                             badRegion = True
                             continue
 
-                        region.foldRatio = getFoldRatio(regionFinder, controlRDS, region.numReads, chromosome, region.start, lastReadPos,
-                                                        useMulti, doControl, normalize)
+                        region.foldRatio = getFoldRatio(regionFinder, controlRDS, region.numReads, chromosome, region.start, lastReadPos, useMulti, doControl)
 
                     # just in case it changed, use latest data
                     try:
                         bestPos = peak.topPos[0]
                         peakScore = peak.smoothArray[bestPos]
-                        if normalize:
+                        if regionFinder.normalize:
                             peakScore /= regionFinder.sampleRDSsize
                     except:
                         continue
 
-                    if listPeak:
+                    if regionFinder.listPeak:
                         region.peakDescription= "%d\t%.1f" % (region.start + bestPos, peakScore)
 
                     if useMulti:
-                        setMultireadPercentage(region, hitRDS, regionFinder.sampleRDSsize, totalWeight, uniqueReadCount, chromosome, lastReadPos, normalize, regionFinder.doTrim)
+                        setMultireadPercentage(region, hitRDS, regionFinder.sampleRDSsize, totalWeight, uniqueReadCount, chromosome, lastReadPos,
+                                               regionFinder.normalize, regionFinder.doTrim)
 
                     region.shift = peak.shift
                     # check that we still pass threshold
                     regionLength = lastReadPos - region.start
                     plusRatio = float(peak.numPlus)/peak.numHits
-                    if regionAndPeakPass(region, regionFinder.minHits, regionFinder.minRatio, regionLength, regionFinder.readlen, peakScore, regionFinder.minPeak,
-                                         regionFinder.minPlusRatio, regionFinder.maxPlusRatio, plusRatio):
+                    if regionAndPeakPass(regionFinder, region, regionLength, peakScore, plusRatio):
                         try:
                             updateRegion(region, regionFinder.doDirectionality, regionFinder.leftPlusRatio, peak.numLeftPlus, peak.numPlus, plusRatio)
                             regionFinder.statistics["index"] += 1
@@ -543,8 +552,8 @@ def findPeakRegions(regionFinder, hitRDS, controlRDS, chromosome, logfilename, o
     return allregions, outregions
 
 
-def findBackgroundRegions(regionFinder, hitRDS, controlRDS, chromosome, normalize, useMulti, factor):
-
+def findBackgroundRegions(regionFinder, hitRDS, controlRDS, chromosome, useMulti):
+    #TODO: this is *almost* the same calculation - there are small yet important differences
     print "calculating background..."
     previousHit = - 1 * regionFinder.maxSpacing
     currentHitList = [-1]
@@ -560,16 +569,16 @@ def findBackgroundRegions(regionFinder, hitRDS, controlRDS, chromosome, normaliz
         if previousRegionIsDone(pos, previousHit, regionFinder.maxSpacing, maxCoord):
             lastReadPos = currentHitList[-1]
             lastBasePosition = lastReadPos + regionFinder.readlen - 1
-            region = Region.Region(currentHitList[0], lastBasePosition, chrom=chromosome, label=factor, numReads=currentTotalWeight)
-            if normalize:
+            region = Region.Region(currentHitList[0], lastBasePosition, chrom=chromosome, label=regionFinder.regionLabel, numReads=currentTotalWeight)
+            if regionFinder.normalize:
                 region.numReads /= regionFinder.controlRDSsize
 
             backregions.append(int(region.numReads))
-            region = Region.Region(currentHitList[0], lastBasePosition, chrom=chromosome, label=factor, numReads=currentTotalWeight)
+            region = Region.Region(currentHitList[0], lastBasePosition, chrom=chromosome, label=regionFinder.regionLabel, numReads=currentTotalWeight)
             regionLength = lastReadPos - region.start
-            if regionPassesCriteria(region.numReads, regionFinder.minHits, numStarts, regionFinder.minRatio, regionLength, regionFinder.readlen):
+            if regionPassesCriteria(regionFinder, region.numReads, numStarts, regionLength):
                 numMock = 1. + hitRDS.getCounts(chromosome, region.start, lastReadPos, uniqs=True, multi=useMulti, splices=False, reportCombined=True)
-                if normalize:
+                if regionFinder.normalize:
                     numMock /= regionFinder.sampleRDSsize
 
                 foldRatio = region.numReads / numMock
@@ -580,14 +589,13 @@ def findBackgroundRegions(regionFinder, hitRDS, controlRDS, chromosome, normaliz
 
                     if regionFinder.doTrim:
                         try:
-                            lastReadPos = trimRegion(region, peak, lastReadPos, 20., currentReadList, regionFinder.readlen, regionFinder.doDirectionality,
-                                                     normalize, regionFinder.controlRDSsize)
+                            lastReadPos = trimRegion(region, regionFinder, peak, lastReadPos, 20., currentReadList, regionFinder.controlRDSsize)
                         except IndexError:
                             badRegion = True
                             continue
 
                         numMock = 1. + hitRDS.getCounts(chromosome, region.start, lastReadPos, uniqs=True, multi=useMulti, splices=False, reportCombined=True)
-                        if normalize:
+                        if regionFinder.normalize:
                             numMock /= regionFinder.sampleRDSsize
 
                         foldRatio = region.numReads / numMock
@@ -600,13 +608,13 @@ def findBackgroundRegions(regionFinder, hitRDS, controlRDS, chromosome, normaliz
                         continue
 
                     # normalize to RPM
-                    if normalize:
+                    if regionFinder.normalize:
                         peakScore /= regionFinder.controlRDSsize
 
                     # check that we still pass threshold
                     regionLength = lastReadPos - region.start
-                    if regionPassesCriteria(region.numReads, regionFinder.minHits, foldRatio, regionFinder.minRatio, regionLength, regionFinder.readlen):
-                        updateControlStatistics(regionFinder, peak, region.numReads, peakScore)
+                    if regionPassesCriteria(regionFinder, region.numReads, foldRatio, regionLength):
+                        regionFinder.updateControlStatistics(peak, region.numReads, peakScore)
 
             currentHitList = []
             currentTotalWeight = 0
@@ -628,27 +636,33 @@ def findBackgroundRegions(regionFinder, hitRDS, controlRDS, chromosome, normaliz
     return backregions
 
 
-def learnShift(regionFinder, hitDict, controlRDS, chromosome, maxCoord, logfilename, outfilename,
-               outfile, numStarts, stringency, normalize, useMulti, doControl):
+def learnShift(regionFinder, hitRDS, chromosome, logfilename, outfilename,
+               outfile, useMulti, doControl, controlRDS, combine5p):
+
+    hitDict = hitRDS.getReadsDict(fullChrom=True, chrom=chromosome, flag=regionFinder.withFlag, withWeight=True, doMulti=useMulti, findallOptimize=True,
+                                  strand=regionFinder.stranded, combine5p=combine5p)
 
+    maxCoord = hitRDS.getMaxCoordinate(chromosome, doMulti=useMulti)
     print "learning shift.... will need at least 30 training sites"
+    stringency = regionFinder.stringency
     previousHit = -1 * regionFinder.maxSpacing
     positionList = [-1]
     totalWeight = 0
     readList = []
     shiftDict = {}
     count = 0
+    numStarts = 0
     for read in hitDict[chromosome]:
         pos = read["start"]
         if previousRegionIsDone(pos, previousHit, regionFinder.maxSpacing, maxCoord):
-            if normalize:
+            if regionFinder.normalize:
                 totalWeight /= regionFinder.sampleRDSsize
 
             regionStart = positionList[0]
             regionStop = positionList[-1]
             regionLength = regionStop - regionStart
-            if regionPassesCriteria(totalWeight, regionFinder.minHits, numStarts, regionFinder.minRatio, regionLength, regionFinder.readlen, stringency=stringency):
-                foldRatio = getFoldRatio(regionFinder, controlRDS, totalWeight, chromosome, regionStart, regionStop, useMulti, doControl, normalize)
+            if regionPassesCriteria(regionFinder, totalWeight, numStarts, regionLength, stringency=stringency):
+                foldRatio = getFoldRatio(regionFinder, controlRDS, totalWeight, chromosome, regionStart, regionStop, useMulti, doControl)
                 if foldRatio >= regionFinder.minRatio:
                     updateShiftDict(shiftDict, readList, regionStart, regionLength, regionFinder.readlen)
                     count += 1
@@ -685,15 +699,15 @@ def previousRegionIsDone(pos, previousHit, maxSpacing, maxCoord):
     return abs(pos - previousHit) > maxSpacing or pos == maxCoord
 
 
-def regionPassesCriteria(sumAll, minHits, numStarts, minRatio, regionLength, readlen, stringency=1):
-    return sumAll >= stringency * minHits and numStarts > stringency * minRatio and regionLength > stringency * readlen
+def regionPassesCriteria(regionFinder, sumAll, numStarts, regionLength, stringency=1):  # stringency multiplies each of the three thresholds
+    return sumAll >= stringency * regionFinder.minHits and numStarts > stringency * regionFinder.minRatio and regionLength > stringency * regionFinder.readlen
 
 
-def trimRegion(region, peak, regionStop, trimValue, currentReadList, readlen, doDirectionality, normalize, hitRDSsize):
+def trimRegion(region, regionFinder, peak, regionStop, trimValue, currentReadList, totalReadCount):
     bestPos = peak.topPos[0]
     peakScore = peak.smoothArray[bestPos]
-    if normalize:
-        peakScore /= hitRDSsize
+    if regionFinder.normalize:
+        peakScore /= totalReadCount
 
     minSignalThresh = trimValue * peakScore
     start = findStartEdgePosition(peak, minSignalThresh)
@@ -702,18 +716,21 @@ def trimRegion(region, peak, regionStop, trimValue, currentReadList, readlen, do
 
     regionStop = region.start + stop
     region.start += start
-    trimmedPeak = findPeak(currentReadList, region.start, regionStop - region.start, readlen, doWeight=True, leftPlus=doDirectionality, shift=peak.shift)
+
+    trimmedPeak = findPeak(currentReadList, region.start, regionStop - region.start, regionFinder.readlen, doWeight=True,
+                           leftPlus=regionFinder.doDirectionality, shift=peak.shift)
+
     peak.numPlus = trimmedPeak.numPlus
     peak.numLeftPlus = trimmedPeak.numLeftPlus
     peak.topPos = trimmedPeak.topPos
     peak.smoothArray = trimmedPeak.smoothArray
 
     region.numReads = trimmedPeak.numHits
-    if normalize:
-        region.numReads /= hitRDSsize
+    if regionFinder.normalize:
+        region.numReads /= totalReadCount
 
-    region.stop = regionStop + readlen - 1
-                                
+    region.stop = regionStop + regionFinder.readlen - 1
+                          
     return regionStop
 
 
@@ -736,10 +753,13 @@ def peakEdgeLocated(peak, position, minSignalThresh):
     return peak.smoothArray[position] >= minSignalThresh or position == peak.topPos[0]
 
 
-def getFoldRatio(regionFinder, controlRDS, sumAll, chromosome, regionStart, regionStop, useMulti, doControl, normalize):
+def getFoldRatio(regionFinder, controlRDS, sumAll, chromosome, regionStart, regionStop, useMulti, doControl):
+    """ When doControl is set, the fold ratio is the total read weight (sumAll) divided by one plus the control read count.
+    """
+    #TODO: this needs to be generalized as there is a point at which we want to use the sampleRDS instead of controlRDS
     if doControl:
         numMock = 1. + controlRDS.getCounts(chromosome, regionStart, regionStop, uniqs=True, multi=useMulti, splices=False, reportCombined=True)
-        if normalize:
+        if regionFinder.normalize:
             numMock /= regionFinder.controlRDSsize
 
         foldRatio = sumAll / numMock
@@ -770,22 +790,6 @@ def getShiftValue(shiftDict, count, logfilename, outfilename):
     return shiftValue
 
 
-def updateControlStatistics(regionFinder, peak, sumAll, peakScore):
-
-    plusRatio = float(peak.numPlus)/peak.numHits
-    if peakScore >= regionFinder.minPeak and regionFinder.minPlusRatio <= plusRatio <= regionFinder.maxPlusRatio:
-        if regionFinder.doDirectionality:
-            if regionFinder.leftPlusRatio < peak.numLeft / peak.numPlus:
-                regionFinder.statistics["mIndex"] += 1
-                regionFinder.statistics["mTotal"] += sumAll
-            else:
-                regionFinder.statistics["failed"] += 1
-        else:
-            # we have a region, but didn't check for directionality
-            regionFinder.statistics["mIndex"] += 1
-            regionFinder.statistics["mTotal"] += sumAll
-
-
 def getRegion(regionStart, regionStop, factor, index, chromosome, sumAll, foldRatio, multiP,
               peakDescription, shift, doDirectionality, leftPlusRatio, numLeft,
               numPlus, plusRatio):
@@ -828,18 +832,16 @@ def setMultireadPercentage(region, hitRDS, hitRDSsize, currentTotalWeight, curre
     region.multiP = multiP
 
 
-def regionAndPeakPass(region, minHits, minRatio, regionLength, readlen, peakScore, minPeak, minPlusRatio, maxPlusRatio, plusRatio):
+def regionAndPeakPass(regionFinder, region, regionLength, peakScore, plusRatio):
     regionPasses = False
-    if regionPassesCriteria(region.numReads, minHits, region.foldRatio, minRatio, regionLength, readlen):
-        if peakScore >= minPeak and minPlusRatio <= plusRatio <= maxPlusRatio:
+    if regionPassesCriteria(regionFinder, region.numReads, region.foldRatio, regionLength):
+        if peakScore >= regionFinder.minPeak and regionFinder.minPlusRatio <= plusRatio <= regionFinder.maxPlusRatio:
             regionPasses = True
 
     return regionPasses
 
 
-def updateRegion(region,
-                 doDirectionality, leftPlusRatio, numLeft,
-                 numPlus, plusRatio):
+def updateRegion(region, doDirectionality, leftPlusRatio, numLeft, numPlus, plusRatio):
 
     if doDirectionality:
         if leftPlusRatio < numLeft / numPlus:
@@ -849,14 +851,14 @@ def updateRegion(region,
             raise RegionDirectionError
 
 
-def writeNoRevBackgroundResults(regionFinder, outregions, outfile, doPvalue, reportshift, shiftDict,
+def writeNoRevBackgroundResults(regionFinder, outregions, outfile, doPvalue, shiftDict,
                                 allregions, header):
 
-    writeChromosomeResults(regionFinder, outregions, outfile, doPvalue, reportshift, shiftDict,
+    writeChromosomeResults(regionFinder, outregions, outfile, doPvalue, shiftDict,
                            allregions, header, backregions=[], pValueType="self")
 
 
-def writeChromosomeResults(regionFinder, outregions, outfile, doPvalue, reportshift, shiftDict,
+def writeChromosomeResults(regionFinder, outregions, outfile, doPvalue, shiftDict,
                            allregions, header, backregions=[], pValueType="none"):
 
     print regionFinder.statistics["mIndex"], regionFinder.statistics["mTotal"]
@@ -867,7 +869,7 @@ def writeChromosomeResults(regionFinder, outregions, outfile, doPvalue, reportsh
             poissonmean = calculatePoissonMean(backregions)
 
     print header
-    writeRegions(outregions, outfile, doPvalue, poissonmean, shiftValue=regionFinder.shiftValue, reportshift=reportshift, shiftDict=shiftDict)
+    writeRegions(outregions, outfile, doPvalue, poissonmean, shiftValue=regionFinder.shiftValue, reportshift=regionFinder.reportshift, shiftDict=shiftDict)
 
 
 def calculatePoissonMean(dataList):
@@ -905,7 +907,6 @@ def writeRegions(outregions, outfile, doPvalue, poissonmean, shiftValue=0, repor
 
 def calculatePValue(sum, poissonmean):
     pValue = math.exp(-poissonmean)
-    #TODO: 798: DeprecationWarning: integer argument expected, got float - for i in xrange(sum)
     for i in xrange(sum):
         pValue *= poissonmean
         pValue /= i+1
@@ -922,7 +923,7 @@ def getRegionString(region, reportShift):
     return outline
 
 
-def getFooter(regionFinder, shiftDict, reportshift, doRevBackground):
+def getFooter(regionFinder, shiftDict, doRevBackground):
     index = regionFinder.statistics["index"]
     mIndex = regionFinder.statistics["mIndex"]
     footerLines = ["#stats:\t%.1f RPM in %d regions" % (regionFinder.statistics["total"], index)]
@@ -937,7 +938,7 @@ def getFooter(regionFinder, shiftDict, reportshift, doRevBackground):
 
         footerLines.append("#%d regions (%.1f RPM) found in background (FDR = %.2f percent)" % (mIndex, regionFinder.statistics["mTotal"], percent))
 
-    if regionFinder.shiftValue == "auto" and reportshift:
+    if regionFinder.shiftValue == "auto" and regionFinder.reportshift:
         bestShift = getBestShiftInDict(shiftDict)
         footerLines.append("#mode of shift values: %d" % bestShift)