rewrite of findall.py and MakeRdsFromBam to fix bugs resulting from poor initial...
[erange.git] / findall.py
index d20608f1997e2e6a0fd84d9169c78025485e4b57..f2c70432b1ccc477ca3ccee55b708a0e3e82f4bf 100755 (executable)
@@ -49,14 +49,194 @@ import sys
 import math
 import string
 import optparse
 import math
 import string
 import optparse
+import operator
 from commoncode import writeLog, findPeak, getBestShiftForRegion, getConfigParser, getConfigOption, getConfigIntOption, getConfigFloatOption, getConfigBoolOption
 import ReadDataset
 import Region
 
 
 from commoncode import writeLog, findPeak, getBestShiftForRegion, getConfigParser, getConfigOption, getConfigIntOption, getConfigFloatOption, getConfigBoolOption
 import ReadDataset
 import Region
 
 
-versionString = "findall: version 3.2"
+versionString = "findall: version 3.2.1"
 print versionString
 
 print versionString
 
+class RegionDirectionError(Exception):
+    pass
+            
+
+class RegionFinder():
+    def __init__(self, label, minRatio=4.0, minPeak=0.5, minPlusRatio=0.25, maxPlusRatio=0.75, leftPlusRatio=0.3, strandfilter="",
+                 minHits=4.0, trimValue=0.1, doTrim=True, doDirectionality=True, shiftValue=0, maxSpacing=50, withFlag="",
+                 normalize=True, listPeak=False, reportshift=False, stringency=1.0):
+
+        self.statistics = {"index": 0,
+                           "total": 0,
+                           "mIndex": 0,
+                           "mTotal": 0,
+                           "failed": 0,
+                           "badRegionTrim": 0
+        }
+
+        self.regionLabel = label
+        self.rnaSettings = False
+        self.controlRDSsize = 1
+        self.sampleRDSsize = 1
+        self.minRatio = minRatio
+        self.minPeak = minPeak
+        self.leftPlusRatio = leftPlusRatio
+        self.stranded = "both"
+        if strandfilter == "plus":
+            self.stranded = "+"
+            minPlusRatio = 0.9
+            maxPlusRatio = 1.0
+        elif strandfilter == "minus":
+            self.stranded = "-"
+            minPlusRatio = 0.0
+            maxPlusRatio = 0.1
+
+        if minRatio < minPeak:
+            self.minPeak = minRatio
+
+        self.minPlusRatio = minPlusRatio
+        self.maxPlusRatio = maxPlusRatio
+        self.strandfilter = strandfilter
+        self.minHits = minHits
+        self.trimValue = trimValue
+        self.doTrim = doTrim
+        self.doDirectionality = doDirectionality
+
+        if self.doTrim:
+            self.trimString = string.join(["%2.1f" % (100. * self.trimValue), "%"], "")
+        else:
+            self.trimString = "none"
+
+        self.shiftValue = shiftValue
+        self.maxSpacing = maxSpacing
+        self.withFlag = withFlag
+        self.normalize = normalize
+        self.listPeak = listPeak
+        self.reportshift = reportshift
+        self.stringency = max(stringency, 1.0)
+
+
+    def useRNASettings(self, readlen):
+        self.rnaSettings = True
+        self.shiftValue = 0
+        self.doTrim = False
+        self.doDirectionality = False
+        self.maxSpacing = readlen
+
+
+    def getHeader(self, doPvalue):
+        if self.normalize:
+            countType = "RPM"
+        else:
+            countType = "COUNT"
+
+        headerFields = ["#regionID\tchrom\tstart\tstop", countType, "fold\tmulti%"]
+
+        if self.doDirectionality:
+            headerFields.append("plus%\tleftPlus%")
+
+        if self.listPeak:
+            headerFields.append("peakPos\tpeakHeight")
+
+        if self.reportshift:
+            headerFields.append("readShift")
+
+        if doPvalue:
+            headerFields.append("pValue")
+
+        return string.join(headerFields, "\t")
+
+
+    def printSettings(self, doRevBackground, ptype, doControl, useMulti, doCache, pValueType):
+        print
+        self.printStatusMessages(doRevBackground, ptype, doControl, useMulti)
+        self.printOptionsSummary(useMulti, doCache, pValueType)
+
+
+    def printStatusMessages(self, doRevBackground, ptype, doControl, useMulti):
+        if self.shiftValue == "learn":
+            print "Will try to learn shift"
+
+        if self.normalize:
+            print "Normalizing to RPM"
+
+        if doRevBackground:
+            print "Swapping IP and background to calculate FDR"
+
+        if ptype != "":
+            if ptype in ["NONE", "SELF"]:
+                pass
+            elif ptype == "BACK":
+                if doControl and doRevBackground:
+                    pass
+                else:
+                    print "must have a control dataset and -revbackground for pValue type 'back'"
+            else:
+                print "could not use pValue type : %s" % ptype
+
+        if self.withFlag != "":
+            print "restrict to flag = %s" % self.withFlag
+
+        if not useMulti:
+            print "using unique reads only"
+
+        if self.rnaSettings:
+            print "using settings appropriate for RNA: -nodirectionality -notrim -noshift"
+
+        if self.strandfilter == "plus":
+            print "only analyzing reads on the plus strand"
+        elif self.strandfilter == "minus":
+            print "only analyzing reads on the minus strand"
+
+
+    def printOptionsSummary(self, useMulti, doCache, pValueType):
+
+        print "\nenforceDirectionality=%s listPeak=%s nomulti=%s cache=%s " % (self.doDirectionality, self.listPeak, not useMulti, doCache)
+        print "spacing<%d minimum>%.1f ratio>%.1f minPeak=%.1f\ttrimmed=%s\tstrand=%s" % (self.maxSpacing, self.minHits, self.minRatio, self.minPeak, self.trimString, self.stranded)
+        try:
+            print "minPlus=%.2f maxPlus=%.2f leftPlus=%.2f shift=%d pvalue=%s" % (self.minPlusRatio, self.maxPlusRatio, self.leftPlusRatio, self.shiftValue, pValueType)
+        except:
+            print "minPlus=%.2f maxPlus=%.2f leftPlus=%.2f shift=%s pvalue=%s" % (self.minPlusRatio, self.maxPlusRatio, self.leftPlusRatio, self.shiftValue, pValueType)
+
+
+    def getAnalysisDescription(self, hitfile, useMulti, doCache, pValueType, controlfile, doControl):
+
+        description = ["#ERANGE %s" % versionString]
+        if doControl:
+            description.append("#enriched sample:\t%s (%.1f M reads)\n#control sample:\t%s (%.1f M reads)" % (hitfile, self.sampleRDSsize, controlfile, self.controlRDSsize))
+        else:
+            description.append("#enriched sample:\t%s (%.1f M reads)\n#control sample: none" % (hitfile, self.sampleRDSsize))
+
+        if self.withFlag != "":
+            description.append("#restrict to Flag = %s" % self.withFlag)
+
+        description.append("#enforceDirectionality=%s listPeak=%s nomulti=%s cache=%s" % (self.doDirectionality, self.listPeak, not useMulti, doCache))
+        description.append("#spacing<%d minimum>%.1f ratio>%.1f minPeak=%.1f trimmed=%s strand=%s" % (self.maxSpacing, self.minHits, self.minRatio, self.minPeak, self.trimString, self.stranded))
+        try:
+            description.append("#minPlus=%.2f maxPlus=%.2f leftPlus=%.2f shift=%d pvalue=%s" % (self.minPlusRatio, self.maxPlusRatio, self.leftPlusRatio, self.shiftValue, pValueType))
+        except:
+            description.append("#minPlus=%.2f maxPlus=%.2f leftPlus=%.2f shift=%s pvalue=%s" % (self.minPlusRatio, self.maxPlusRatio, self.leftPlusRatio, self.shiftValue, pValueType))
+
+        return string.join(description, "\n")
+
+
+    def updateControlStatistics(self, peak, sumAll, peakScore):
+
+        plusRatio = float(peak.numPlus)/peak.numHits
+        if peakScore >= self.minPeak and self.minPlusRatio <= plusRatio <= self.maxPlusRatio:
+            if self.doDirectionality:
+                if self.leftPlusRatio < peak.numLeft / peak.numPlus:
+                    self.statistics["mIndex"] += 1
+                    self.statistics["mTotal"] += sumAll
+                else:
+                    self.statistics["failed"] += 1
+            else:
+                # we have a region, but didn't check for directionality
+                self.statistics["mIndex"] += 1
+                self.statistics["mTotal"] += sumAll
+
+
 def usage():
     print __doc__
 
 def usage():
     print __doc__
 
@@ -76,20 +256,43 @@ def main(argv=None):
     hitfile = args[1]
     outfilename = args[2]
 
     hitfile = args[1]
     outfilename = args[2]
 
-    findall(factor, hitfile, outfilename, options.minHits, options.minRatio, options.maxSpacing, options.listPeak, options.shift,
-            options.stringency, options.noshift, options.autoshift, options.reportshift,
-            options.minPlusRatio, options.maxPlusRatio, options.leftPlusRatio, options.minPeak,
-            options.normalize, options.logfilename, options.withFlag, options.doDirectionality,
-            options.trimValue, options.doTrim, options.doAppend, options.rnaSettings,
-            options.cachePages, options.ptype, options.mockfile, options.doRevBackground, options.noMulti,
-            options.strandfilter, options.combine5p)
+    shiftValue = 0
+
+    if options.autoshift:
+        shiftValue = "auto"
+
+    if options.shift is not None:
+        try:
+            shiftValue = int(options.shift)
+        except ValueError:
+            if options.shift == "learn":
+                shiftValue = "learn"
+
+    if options.noshift:
+        shiftValue = 0
+
+    if options.doAppend:
+        outputMode = "a"
+    else:
+        outputMode = "w"
+
+    regionFinder = RegionFinder(factor, minRatio=options.minRatio, minPeak=options.minPeak, minPlusRatio=options.minPlusRatio,
+                                maxPlusRatio=options.maxPlusRatio, leftPlusRatio=options.leftPlusRatio, strandfilter=options.strandfilter,
+                                minHits=options.minHits, trimValue=options.trimValue, doTrim=options.doTrim,
+                                doDirectionality=options.doDirectionality, shiftValue=shiftValue, maxSpacing=options.maxSpacing,
+                                withFlag=options.withFlag, normalize=options.normalize, listPeak=options.listPeak,
+                                reportshift=options.reportshift, stringency=options.stringency)
+
+    findall(regionFinder, hitfile, outfilename, options.logfilename, outputMode, options.rnaSettings,
+            options.cachePages, options.ptype, options.controlfile, options.doRevBackground,
+            options.useMulti, options.combine5p)
 
 
 def makeParser():
     usage = __doc__
 
     parser = optparse.OptionParser(usage=usage)
 
 
 def makeParser():
     usage = __doc__
 
     parser = optparse.OptionParser(usage=usage)
-    parser.add_option("--control", dest="mockfile")
+    parser.add_option("--control", dest="controlfile")
     parser.add_option("--minimum", type="float", dest="minHits")
     parser.add_option("--ratio", type="float", dest="minRatio")
     parser.add_option("--spacing", type="int", dest="maxSpacing")
     parser.add_option("--minimum", type="float", dest="minHits")
     parser.add_option("--ratio", type="float", dest="minRatio")
     parser.add_option("--spacing", type="int", dest="maxSpacing")
@@ -99,7 +302,7 @@ def makeParser():
     parser.add_option("--noshift", action="store_true", dest="noShift")
     parser.add_option("--autoshift", action="store_true", dest="autoshift")
     parser.add_option("--reportshift", action="store_true", dest="reportshift")
     parser.add_option("--noshift", action="store_true", dest="noShift")
     parser.add_option("--autoshift", action="store_true", dest="autoshift")
     parser.add_option("--reportshift", action="store_true", dest="reportshift")
-    parser.add_option("--nomulti", action="store_true", dest="noMulti")
+    parser.add_option("--nomulti", action="store_false", dest="useMulti")
     parser.add_option("--minPlus", type="float", dest="minPlusRatio")
     parser.add_option("--maxPlus", type="float", dest="maxPlusRatio")
     parser.add_option("--leftPlus", type="float", dest="leftPlusRatio")
     parser.add_option("--minPlus", type="float", dest="minPlusRatio")
     parser.add_option("--maxPlus", type="float", dest="maxPlusRatio")
     parser.add_option("--leftPlus", type="float", dest="leftPlusRatio")
@@ -137,16 +340,16 @@ def makeParser():
     logfilename = getConfigOption(configParser, section, "logfilename", "findall.log")
     withFlag = getConfigOption(configParser, section, "withFlag", "")
     doDirectionality = getConfigBoolOption(configParser, section, "doDirectionality", True)
     logfilename = getConfigOption(configParser, section, "logfilename", "findall.log")
     withFlag = getConfigOption(configParser, section, "withFlag", "")
     doDirectionality = getConfigBoolOption(configParser, section, "doDirectionality", True)
-    trimValue = getConfigOption(configParser, section, "trimValue", None)
+    trimValue = getConfigFloatOption(configParser, section, "trimValue", 0.1)
     doTrim = getConfigBoolOption(configParser, section, "doTrim", True)
     doAppend = getConfigBoolOption(configParser, section, "doAppend", False)
     rnaSettings = getConfigBoolOption(configParser, section, "rnaSettings", False)
     cachePages = getConfigOption(configParser, section, "cachePages", None)
     doTrim = getConfigBoolOption(configParser, section, "doTrim", True)
     doAppend = getConfigBoolOption(configParser, section, "doAppend", False)
     rnaSettings = getConfigBoolOption(configParser, section, "rnaSettings", False)
     cachePages = getConfigOption(configParser, section, "cachePages", None)
-    ptype = getConfigOption(configParser, section, "ptype", None)
-    mockfile = getConfigOption(configParser, section, "mockfile", None)
+    ptype = getConfigOption(configParser, section, "ptype", "")
+    controlfile = getConfigOption(configParser, section, "controlfile", None)
     doRevBackground = getConfigBoolOption(configParser, section, "doRevBackground", False)
     doRevBackground = getConfigBoolOption(configParser, section, "doRevBackground", False)
-    noMulti = getConfigBoolOption(configParser, section, "noMulti", False)
-    strandfilter = getConfigOption(configParser, section, "strandfilter", None)
+    useMulti = getConfigBoolOption(configParser, section, "useMulti", True)
+    strandfilter = getConfigOption(configParser, section, "strandfilter", "")
     combine5p = getConfigBoolOption(configParser, section, "combine5p", False)
 
     parser.set_defaults(minHits=minHits, minRatio=minRatio, maxSpacing=maxSpacing, listPeak=listPeak, shift=shift,
     combine5p = getConfigBoolOption(configParser, section, "combine5p", False)
 
     parser.set_defaults(minHits=minHits, minRatio=minRatio, maxSpacing=maxSpacing, listPeak=listPeak, shift=shift,
@@ -154,609 +357,600 @@ def makeParser():
                         minPlusRatio=minPlusRatio, maxPlusRatio=maxPlusRatio, leftPlusRatio=leftPlusRatio, minPeak=minPeak,
                         normalize=normalize, logfilename=logfilename, withFlag=withFlag, doDirectionality=doDirectionality,
                         trimValue=trimValue, doTrim=doTrim, doAppend=doAppend, rnaSettings=rnaSettings,
                         minPlusRatio=minPlusRatio, maxPlusRatio=maxPlusRatio, leftPlusRatio=leftPlusRatio, minPeak=minPeak,
                         normalize=normalize, logfilename=logfilename, withFlag=withFlag, doDirectionality=doDirectionality,
                         trimValue=trimValue, doTrim=doTrim, doAppend=doAppend, rnaSettings=rnaSettings,
-                        cachePages=cachePages, ptype=ptype, mockfile=mockfile, doRevBackground=doRevBackground, noMulti=noMulti,
+                        cachePages=cachePages, ptype=ptype, controlfile=controlfile, doRevBackground=doRevBackground, useMulti=useMulti,
                         strandfilter=strandfilter, combine5p=combine5p)
 
     return parser
 
 
                         strandfilter=strandfilter, combine5p=combine5p)
 
     return parser
 
 
-def findall(factor, hitfile, outfilename, minHits=4.0, minRatio=4.0, maxSpacing=50, listPeak=False, shift=None,
-            stringency=4.0, noshift=False, autoshift=False, reportshift=False,
-            minPlusRatio=0.25, maxPlusRatio=0.75, leftPlusRatio=0.3, minPeak=0.5,
-            normalize=True, logfilename="findall.log", withFlag="", doDirectionality=True,
-            trimValue=None, doTrim=True, doAppend=False, rnaSettings=False,
-            cachePages=None, ptype=None, mockfile=None, doRevBackground=False, noMulti=False,
-            strandfilter=None, combine5p=False):
+def findall(regionFinder, hitfile, outfilename, logfilename="findall.log", outputMode="w", rnaSettings=False, cachePages=None,
+            ptype="", controlfile=None, doRevBackground=False, useMulti=True, combine5p=False):
 
 
-    shiftValue = determineShiftValue(autoshift, shift, noshift, rnaSettings)
+    writeLog(logfilename, versionString, string.join(sys.argv[1:]))
+    doCache = cachePages is not None
+    controlRDS = None
+    doControl = controlfile is not None
+    if doControl:
+        print "\ncontrol:" 
+        controlRDS = openRDSFile(controlfile, cachePages=cachePages, doCache=doCache)
+        regionFinder.controlRDSsize = len(controlRDS) / 1000000.
 
 
-    if trimValue is not None:
-        trimValue = float(trimValue) / 100.
-        trimString = "%2.1f%s" % ((100. * trimValue), "%")
-    else:
-        trimValue = 0.1
-        trimString = "10%"
+    print "\nsample:" 
+    hitRDS = openRDSFile(hitfile, cachePages=cachePages, doCache=doCache)
+    regionFinder.sampleRDSsize = len(hitRDS) / 1000000.
+    pValueType = getPValueType(ptype, doControl, doRevBackground)
+    doPvalue = not pValueType == "none"
+    regionFinder.readlen = hitRDS.getReadSize()
+    if rnaSettings:
+        regionFinder.useRNASettings(regionFinder.readlen)
 
 
-    if not doTrim:
-        trimString = "none"
+    regionFinder.printSettings(doRevBackground, ptype, doControl, useMulti, doCache, pValueType)
+    outfile = open(outfilename, outputMode)
+    header = writeOutputFileHeader(regionFinder, outfile, hitfile, useMulti, doCache, pValueType, doPvalue, controlfile, doControl)
+    shiftDict = {}
+    chromosomeList = getChromosomeListToProcess(hitRDS, controlRDS, doControl)
+    for chromosome in chromosomeList:
+        if regionFinder.shiftValue == "learn":
+            learnShift(regionFinder, hitRDS, chromosome, logfilename, outfilename, outfile, useMulti, doControl, controlRDS, combine5p)
+
+        allregions, outregions = findPeakRegions(regionFinder, hitRDS, chromosome, logfilename, outfilename, outfile, useMulti, doControl, controlRDS, combine5p)
+        if doRevBackground:
+            backregions = findBackgroundRegions(regionFinder, hitRDS, controlRDS, chromosome, useMulti)
+            writeChromosomeResults(regionFinder, outregions, outfile, doPvalue, shiftDict, allregions, header, backregions=backregions, pValueType=pValueType)
+        else:
+            writeNoRevBackgroundResults(regionFinder, outregions, outfile, doPvalue, shiftDict, allregions, header)
 
 
-    if doRevBackground:
-        print "Swapping IP and background to calculate FDR"
-        pValueType = "back"
+    footer = getFooter(regionFinder, shiftDict, doRevBackground)
+    print footer
+    print >> outfile, footer
+    outfile.close()
+    writeLog(logfilename, versionString, outfilename + footer.replace("\n#"," | ")[:-1])
 
 
-    doControl = False
-    if mockfile is not None:
-        doControl = True
 
 
-    doPvalue = True
-    if ptype is not None:
-        ptype = ptype.upper()
+def getPValueType(ptype, doControl, doRevBackground):
+    pValueType = "self"
+    if ptype in ["NONE", "SELF", "BACK"]:
         if ptype == "NONE":
         if ptype == "NONE":
-            doPvalue = False
             pValueType = "none"
             pValueType = "none"
-            p = 1
-            poissonmean = 0
         elif ptype == "SELF":
             pValueType = "self"
         elif ptype == "BACK":
             if doControl and doRevBackground:
                 pValueType = "back"
         elif ptype == "SELF":
             pValueType = "self"
         elif ptype == "BACK":
             if doControl and doRevBackground:
                 pValueType = "back"
-            else:
-                print "must have a control dataset and -revbackground for pValue type 'back'"
-        else:
-            print "could not use pValue type : %s" % ptype
-    else:
-        pValueType = "self"
-
-    if cachePages is not None:
-        doCache = True
-    else:
-        doCache = False
-        cachePages = -1
-
-    if withFlag != "":
-        print "restrict to flag = %s" % withFlag
+    elif doRevBackground:
+        pValueType = "back"
 
 
-    useMulti = True
-    if noMulti:
-        print "using unique reads only"
-        useMulti = False
+    return pValueType
 
 
-    if rnaSettings:
-        print "using settings appropriate for RNA: -nodirectionality -notrim -noshift"
-        doTrim = False
-        doDirectionality = False
 
 
-    stranded = ""
-    if strandfilter is not None:
-        if strandfilter == "plus":
-            stranded = "+"
-            minPlusRatio = 0.9
-            maxPlusRatio = 1.0
-            print "only analyzing reads on the plus strand"
-        elif strandfilter == "minus":
-            stranded = "-"
-            minPlusRatio = 0.0
-            maxPlusRatio = 0.1
-            print "only analyzing reads on the minus strand"
-
-    stringency = max(stringency, 1.0)
-    writeLog(logfilename, versionString, string.join(sys.argv[1:]))
-    if doControl:
-        print "\ncontrol:" 
-        mockRDS = ReadDataset.ReadDataset(mockfile, verbose=True, cache=doCache)
+def openRDSFile(filename, cachePages=None, doCache=False):
+    rds = ReadDataset.ReadDataset(filename, verbose=True, cache=doCache)
+    if cachePages > rds.getDefaultCacheSize():
+        rds.setDBcache(cachePages)
 
 
-        if cachePages > mockRDS.getDefaultCacheSize():
-            mockRDS.setDBcache(cachePages)
+    return rds
 
 
-    print "\nsample:" 
-    hitRDS = ReadDataset.ReadDataset(hitfile, verbose=True, cache=doCache)
-    readlen = hitRDS.getReadSize()
-    if rnaSettings:
-        maxSpacing = readlen
 
 
-    if cachePages > hitRDS.getDefaultCacheSize():
-        hitRDS.setDBcache(cachePages)
+def writeOutputFileHeader(regionFinder, outfile, hitfile, useMulti, doCache, pValueType, doPvalue, controlfile, doControl):
+    print >> outfile, regionFinder.getAnalysisDescription(hitfile, useMulti, doCache, pValueType, controlfile, doControl)
+    header = regionFinder.getHeader(doPvalue)
+    print >> outfile, header
 
 
-    if doAppend:
-        fileMode = "a"
-    else:
-        fileMode = "w"
+    return header
 
 
-    outfile = open(outfilename, fileMode)
 
 
-    outfile.write("#ERANGE %s\n" % versionString)
+def getChromosomeListToProcess(hitRDS, controlRDS=None, doControl=False):
+    hitChromList = hitRDS.getChromosomes()
     if doControl:
     if doControl:
-        mockRDSsize = len(mockRDS) / 1000000.
-        controlSampleString = "\t%s (%.1f M reads)" % (mockfile, mockRDSsize)
+        controlChromList = controlRDS.getChromosomes()
+        chromosomeList = [chrom for chrom in hitChromList if chrom in controlChromList and chrom != "chrM"]
     else:
     else:
-        controlSampleString = " none"
+        chromosomeList = [chrom for chrom in hitChromList if chrom != "chrM"]
 
 
-    hitRDSsize = len(hitRDS) / 1000000.
-    outfile.write("#enriched sample:\t%s (%.1f M reads)\n#control sample:%s\n" % (hitfile, hitRDSsize, controlSampleString))
+    return chromosomeList
 
 
-    if withFlag != "":
-        outfile.write("#restrict to Flag = %s\n" % withFlag)
 
 
-    print "\nenforceDirectionality=%s listPeak=%s nomulti=%s cache=%s " % (doDirectionality, listPeak, noMulti, doCache)
-    print "spacing<%d minimum>%.1f ratio>%.1f minPeak=%.1f\ttrimmed=%s\tstrand=%s" % (maxSpacing, minHits, minRatio, minPeak, trimString, stranded)
-    print "minPlus=%.2f maxPlus=%.2f leftPlus=%.2f shift=%s pvalue=%s" % (minPlusRatio, maxPlusRatio, leftPlusRatio, str(shiftValue), pValueType)
+def findPeakRegions(regionFinder, hitRDS, chromosome, logfilename, outfilename,
+                    outfile, useMulti, doControl, controlRDS, combine5p):
 
 
-    outfile.write("#enforceDirectionality=%s listPeak=%s nomulti=%s cache=%s\n" % (doDirectionality, listPeak, noMulti, doCache))
-    outfile.write("#spacing<%d minimum>%.1f ratio>%.1f minPeak=%.1f trimmed=%s strand=%s\n" % (maxSpacing, minHits, minRatio, minPeak, trimString, stranded))
-    outfile.write("#minPlus=%.2f maxPlus=%.2f leftPlus=%.2f shift=%s pvalue=%s\n" % (minPlusRatio, maxPlusRatio, leftPlusRatio, str(shiftValue), pValueType))
-    if normalize:
-        print "Normalizing to RPM"
-        countLabel = "RPM"
-    else:
-        countLabel = "COUNT"
+    outregions = []
+    allregions = []
+    print "chromosome %s" % (chromosome)
+    previousHit = - 1 * regionFinder.maxSpacing
+    readStartPositions = [-1]
+    totalWeight = 0
+    uniqueReadCount = 0
+    reads = []
+    numStarts = 0
+    badRegion = False
+    hitDict = hitRDS.getReadsDict(fullChrom=True, chrom=chromosome, flag=regionFinder.withFlag, withWeight=True, doMulti=useMulti, findallOptimize=True,
+                                  strand=regionFinder.stranded, combine5p=combine5p)
 
 
-    headerList = ["#regionID\tchrom\tstart\tstop", countLabel, "fold\tmulti%"]
-    if doDirectionality:
-        headerList.append("plus%\tleftPlus%")
+    maxCoord = hitRDS.getMaxCoordinate(chromosome, doMulti=useMulti)
+    for read in hitDict[chromosome]:
+        pos = read["start"]
+        if previousRegionIsDone(pos, previousHit, regionFinder.maxSpacing, maxCoord):
+            lastReadPos = readStartPositions[-1]
+            lastBasePosition = lastReadPos + regionFinder.readlen - 1
+            newRegionIndex = regionFinder.statistics["index"] + 1
+            if regionFinder.doDirectionality:
+                region = Region.DirectionalRegion(readStartPositions[0], lastBasePosition, chrom=chromosome, index=newRegionIndex, label=regionFinder.regionLabel,
+                                                  numReads=totalWeight)
+            else:
+                region = Region.Region(readStartPositions[0], lastBasePosition, chrom=chromosome, index=newRegionIndex, label=regionFinder.regionLabel, numReads=totalWeight)
 
 
-    if listPeak:
-        headerList.append("peakPos\tpeakHeight")
+            if regionFinder.normalize:
+                region.numReads /= regionFinder.sampleRDSsize
 
 
-    if reportshift:
-        headerList.append("readShift")
+            allregions.append(int(region.numReads))
+            regionLength = lastReadPos - region.start
+            if regionPassesCriteria(regionFinder, region.numReads, numStarts, regionLength):
+                region.foldRatio = getFoldRatio(regionFinder, controlRDS, region.numReads, chromosome, region.start, lastReadPos, useMulti, doControl)
 
 
-    if doPvalue:
-        headerList.append("pValue")
+                if region.foldRatio >= regionFinder.minRatio:
+                    # first pass, with absolute numbers
+                    peak = findPeak(reads, region.start, regionLength, regionFinder.readlen, doWeight=True, leftPlus=regionFinder.doDirectionality, shift=regionFinder.shiftValue)
+                    if regionFinder.doTrim:
+                        try:
+                            lastReadPos = trimRegion(region, regionFinder, peak, lastReadPos, regionFinder.trimValue, reads, regionFinder.sampleRDSsize)
+                        except IndexError:
+                            badRegion = True
+                            continue
 
 
-    headline = string.join(headerList, "\t")
-    print >> outfile, headline
+                        region.foldRatio = getFoldRatio(regionFinder, controlRDS, region.numReads, chromosome, region.start, lastReadPos, useMulti, doControl)
 
 
-    statistics = {"index": 0,
-                  "total": 0,
-                  "mIndex": 0,
-                  "mTotal": 0,
-                  "failed": 0
-    }
+                    # just in case it changed, use latest data
+                    try:
+                        bestPos = peak.topPos[0]
+                        peakScore = peak.smoothArray[bestPos]
+                        if regionFinder.normalize:
+                            peakScore /= regionFinder.sampleRDSsize
+                    except:
+                        continue
 
 
-    if minRatio < minPeak:
-        minPeak = minRatio
+                    if regionFinder.listPeak:
+                        region.peakDescription= "%d\t%.1f" % (region.start + bestPos, peakScore)
 
 
-    hitChromList = hitRDS.getChromosomes()
-    if doControl:
-        mockChromList = mockRDS.getChromosomes()
+                    if useMulti:
+                        setMultireadPercentage(region, hitRDS, regionFinder.sampleRDSsize, totalWeight, uniqueReadCount, chromosome, lastReadPos,
+                                               regionFinder.normalize, regionFinder.doTrim)
 
 
-    if normalize:
-        if doControl:
-            mockSampleSize = mockRDSsize
-
-        hitSampleSize = hitRDSsize
-
-    hitChromList.sort()
-
-    for chromosome in hitChromList:
-        if doNotProcessChromosome(chromosome, doControl, mockChromList):
-            continue
-
-        print "chromosome %s" % (chromosome)
-        hitDict = hitRDS.getReadsDict(fullChrom=True, chrom=chromosome, flag=withFlag, withWeight=True,
-                                      doMulti=useMulti, findallOptimize=True, strand=stranded,
-                                      combine5p=combine5p)
-        maxCoord = hitRDS.getMaxCoordinate(chromosome, doMulti=useMulti)
-        if shiftValue == "learn":
-            shiftValue = learnShift(hitDict, hitSampleSize, mockRDS, chromosome, doControl, useMulti, normalize,
-                                    mockSampleSize, minRatio, maxSpacing, maxCoord, stringency, readlen, minHits,
-                                    logfilename, outfile, outfilename)
-
-        regionStats, allRegionWeights, outregions = locateRegions(hitRDS, hitSampleSize, mockRDS, mockSampleSize,
-                                                                  chromosome, useMulti, normalize, maxSpacing,
-                                                                  doDirectionality, doTrim, minHits, minRatio,
-                                                                  readlen, shiftValue, minPeak, minPlusRatio,
-                                                                  maxPlusRatio, leftPlusRatio, listPeak, noMulti,
-                                                                  doControl, factor, trimValue, outputRegionList=True)
-
-        statistics["index"] += regionStats["index"]
-        statistics["total"] += regionStats["total"]
-        statistics["failed"] += regionStats["failed"]
-        if not doRevBackground:
-            if doPvalue:
-                p, poissonmean = calculatePValue(allRegionWeights)
-
-            print headline
-            shiftModeValue = writeRegionsToFile(outfile, outregions, doPvalue, p, poissonmean, reportshift, shiftValue)
-            continue
-
-        #now do background swapping the two samples around
-        print "calculating background..."
-        backgroundTrimValue = 1/20.
-        backgroundRegionStats, backgroundRegionWeights = locateRegions(mockRDS, mockSampleSize, hitRDS, hitSampleSize,
-                                                                       chromosome, useMulti, normalize, maxSpacing,
-                                                                       doDirectionality, doTrim, minHits, minRatio,
-                                                                       readlen, shiftValue, minPeak, minPlusRatio,
-                                                                       maxPlusRatio, leftPlusRatio, listPeak, noMulti,
-                                                                       doControl, factor, backgroundTrimValue)
-
-        statistics["mIndex"] += backgroundRegionStats["index"]
-        statistics["mTotal"] += backgroundRegionStats["total"]
-        statistics["failed"] += backgroundRegionStats["failed"]
-        print statistics["mIndex"], statistics["mTotal"]
-        if doPvalue:
-            if pValueType == "self":
-                p, poissonmean = calculatePValue(allRegionWeights)
-            else:
-                p, poissonmean = calculatePValue(backgroundRegionWeights)
+                    region.shift = peak.shift
+                    # check that we still pass threshold
+                    regionLength = lastReadPos - region.start
+                    plusRatio = float(peak.numPlus)/peak.numHits
+                    if regionAndPeakPass(regionFinder, region, regionLength, peakScore, plusRatio):
+                        try:
+                            updateRegion(region, regionFinder.doDirectionality, regionFinder.leftPlusRatio, peak.numLeftPlus, peak.numPlus, plusRatio)
+                            regionFinder.statistics["index"] += 1
+                            outregions.append(region)
+                            regionFinder.statistics["total"] += region.numReads
+                        except RegionDirectionError:
+                            regionFinder.statistics["failed"] += 1
+
+            readStartPositions = []
+            totalWeight = 0
+            uniqueReadCount = 0
+            reads = []
+            numStarts = 0
+            if badRegion:
+                badRegion = False
+                regionFinder.statistics["badRegionTrim"] += 1
 
 
-        print headline
-        shiftModeValue = writeRegionsToFile(outfile, outregions, doPvalue, p, poissonmean, reportshift, shiftValue)
+        if pos not in readStartPositions:
+            numStarts += 1
 
 
-    footer = getFooter(statistics, doDirectionality, doRevBackground, shiftValue, reportshift, shiftModeValue)
-    print footer
-    outfile.write(footer)
-    outfile.close()
+        readStartPositions.append(pos)
+        weight = read["weight"]
+        totalWeight += weight
+        if weight == 1.0:
+            uniqueReadCount += 1
 
 
-    writeLog(logfilename, versionString, "%s%s" % (outfilename, footer.replace("\n#", " | ")))
+        reads.append({"start": pos, "sense": read["sense"], "weight": weight})
+        previousHit = pos
 
 
+    return allregions, outregions
 
 
-def determineShiftValue(autoshift, shift, noshift, rnaSettings):
-    shiftValue = 0
-    if autoshift:
-        shiftValue = "auto"
 
 
-    if shift is not None:
-        try:
-            shiftValue = int(shift)
-        except ValueError:
-            if shift == "learn":
-                shiftValue = "learn"
-                print "Will try to learn shift"
+def findBackgroundRegions(regionFinder, hitRDS, controlRDS, chromosome, useMulti):
+    #TODO: this is *almost* the same calculation - there are small yet important differences
+    print "calculating background..."
+    previousHit = - 1 * regionFinder.maxSpacing
+    currentHitList = [-1]
+    currentTotalWeight = 0
+    currentReadList = []
+    backregions = []
+    numStarts = 0
+    badRegion = False
+    hitDict = controlRDS.getReadsDict(fullChrom=True, chrom=chromosome, withWeight=True, doMulti=useMulti, findallOptimize=True)
+    maxCoord = controlRDS.getMaxCoordinate(chromosome, doMulti=useMulti)
+    for read in hitDict[chromosome]:
+        pos = read["start"]
+        if previousRegionIsDone(pos, previousHit, regionFinder.maxSpacing, maxCoord):
+            lastReadPos = currentHitList[-1]
+            lastBasePosition = lastReadPos + regionFinder.readlen - 1
+            region = Region.Region(currentHitList[0], lastBasePosition, chrom=chromosome, label=regionFinder.regionLabel, numReads=currentTotalWeight)
+            if regionFinder.normalize:
+                region.numReads /= regionFinder.controlRDSsize
+
+            backregions.append(int(region.numReads))
+            region = Region.Region(currentHitList[0], lastBasePosition, chrom=chromosome, label=regionFinder.regionLabel, numReads=currentTotalWeight)
+            regionLength = lastReadPos - region.start
+            if regionPassesCriteria(regionFinder, region.numReads, numStarts, regionLength):
+                numMock = 1. + hitRDS.getCounts(chromosome, region.start, lastReadPos, uniqs=True, multi=useMulti, splices=False, reportCombined=True)
+                if regionFinder.normalize:
+                    numMock /= regionFinder.sampleRDSsize
+
+                foldRatio = region.numReads / numMock
+                if foldRatio >= regionFinder.minRatio:
+                    # first pass, with absolute numbers
+                    peak = findPeak(currentReadList, region.start, lastReadPos - region.start, regionFinder.readlen, doWeight=True,
+                                    leftPlus=regionFinder.doDirectionality, shift=regionFinder.shiftValue)
 
 
-    if noshift or rnaSettings:
-        shiftValue = 0
+                    if regionFinder.doTrim:
+                        try:
+                            lastReadPos = trimRegion(region, regionFinder, peak, lastReadPos, 20., currentReadList, regionFinder.controlRDSsize)
+                        except IndexError:
+                            badRegion = True
+                            continue
 
 
-    return shiftValue
+                        numMock = 1. + hitRDS.getCounts(chromosome, region.start, lastReadPos, uniqs=True, multi=useMulti, splices=False, reportCombined=True)
+                        if regionFinder.normalize:
+                            numMock /= regionFinder.sampleRDSsize
 
 
+                        foldRatio = region.numReads / numMock
 
 
-def doNotProcessChromosome(chromosome, doControl, mockChromList):
-    skipChromosome = False
-    if chromosome == "chrM":
-        skipChromosome = True
+                    # just in case it changed, use latest data
+                    try:
+                        bestPos = peak.topPos[0]
+                        peakScore = peak.smoothArray[bestPos]
+                    except IndexError:
+                        continue
 
 
-    if doControl and (chromosome not in mockChromList):
-        skipChromosome = True
+                    # normalize to RPM
+                    if regionFinder.normalize:
+                        peakScore /= regionFinder.controlRDSsize
 
 
-    return skipChromosome
+                    # check that we still pass threshold
+                    regionLength = lastReadPos - region.start
+                    if regionPassesCriteria(regionFinder, region.numReads, foldRatio, regionLength):
+                        regionFinder.updateControlStatistics(peak, region.numReads, peakScore)
 
 
+            currentHitList = []
+            currentTotalWeight = 0
+            currentReadList = []
+            numStarts = 0
+            if badRegion:
+                badRegion = False
+                regionFinder.statistics["badRegionTrim"] += 1
 
 
-def calculatePValue(dataList):
-    dataList.sort()
-    listSize = float(len(dataList))
-    try:
-        poissonmean = sum(dataList) / listSize
-    except ZeroDivisionError:
-        poissonmean = 0
+        if pos not in currentHitList:
+            numStarts += 1
 
 
-    print "Poisson n=%d, p=%f" % (listSize, poissonmean)
-    p = math.exp(-poissonmean)
+        currentHitList.append(pos)
+        weight = read["weight"]
+        currentTotalWeight += weight
+        currentReadList.append({"start": pos, "sense": read["sense"], "weight": weight})
+        previousHit = pos
+
+    return backregions
 
 
-    return p, poissonmean
 
 
+def learnShift(regionFinder, hitRDS, chromosome, logfilename, outfilename,
+               outfile, useMulti, doControl, controlRDS, combine5p):
 
 
-def learnShift(hitDict, hitSampleSize, mockRDS, chrom, doControl, useMulti, normalize, mockSampleSize, minRatio, maxSpacing, maxCoord,
-               stringency, readlen, minHits, logfilename, outfile, outfilename, minSites=30):
+    hitDict = hitRDS.getReadsDict(fullChrom=True, chrom=chromosome, flag=regionFinder.withFlag, withWeight=True, doMulti=useMulti, findallOptimize=True,
+                                  strand=regionFinder.stranded, combine5p=combine5p)
 
 
-    print "learning shift.... will need at least %d training sites" % minSites
-    previousHit = -1 * maxSpacing
-    hitList = [-1]
+    maxCoord = hitRDS.getMaxCoordinate(chromosome, doMulti=useMulti)
+    print "learning shift.... will need at least 30 training sites"
+    stringency = regionFinder.stringency
+    previousHit = -1 * regionFinder.maxSpacing
+    positionList = [-1]
     totalWeight = 0
     readList = []
     shiftDict = {}
     count = 0
     numStarts = 0
     totalWeight = 0
     readList = []
     shiftDict = {}
     count = 0
     numStarts = 0
-    for read in hitDict[chrom]:
+    for read in hitDict[chromosome]:
         pos = read["start"]
         pos = read["start"]
-        sense = read["sense"]
-        weight = read["weight"]
-        if abs(pos - previousHit) > maxSpacing or pos == maxCoord:
-            sumAll = totalWeight
-            if normalize:
-                sumAll /= hitSampleSize
+        if previousRegionIsDone(pos, previousHit, regionFinder.maxSpacing, maxCoord):
+            if regionFinder.normalize:
+                totalWeight /= regionFinder.sampleRDSsize
 
 
-            regionStart = hitList[0]
-            regionStop = hitList[-1]
+            regionStart = positionList[0]
+            regionStop = positionList[-1]
             regionLength = regionStop - regionStart
             regionLength = regionStop - regionStart
-            # we're going to require stringent settings
-            if sumAll >= stringency * minHits and numStarts > stringency * minRatio and regionLength > stringency * readlen:
-                foldRatio = getFoldRatio(mockRDS, chrom, regionStart, regionStop, doControl, useMulti, normalize, mockSampleSize, sumAll, minRatio)
-
-                if foldRatio >= minRatio:
-                    localshift = getBestShiftForRegion(readList, regionStart, regionLength, doWeight=True)
-                    try:
-                        shiftDict[localshift] += 1
-                    except KeyError:
-                        shiftDict[localshift] = 1
-
+            if regionPassesCriteria(regionFinder, totalWeight, numStarts, regionLength, stringency=stringency):
+                foldRatio = getFoldRatio(regionFinder, controlRDS, totalWeight, chromosome, regionStart, regionStop, useMulti, doControl)
+                if foldRatio >= regionFinder.minRatio:
+                    updateShiftDict(shiftDict, readList, regionStart, regionLength, regionFinder.readlen)
                     count += 1
 
                     count += 1
 
-            hitList = []
+            positionList = []
             totalWeight = 0
             readList = []
             totalWeight = 0
             readList = []
-            numStarts = 0
 
 
-        if pos not in hitList:
+        if pos not in positionList:
             numStarts += 1
 
             numStarts += 1
 
-        hitList.append(pos)
+        positionList.append(pos)
+        weight = read["weight"]
         totalWeight += weight
         totalWeight += weight
-        readList.append({"start": pos, "sense": sense, "weight": weight})
+        readList.append({"start": pos, "sense": read["sense"], "weight": weight})
         previousHit = pos
 
         previousHit = pos
 
-    bestShift = 0
-    bestCount = 0
-    learningSettings = ["#learn: stringency=%.2f min_signal=%2.f min_ratio=%.2f min_region_size=%d" % (stringency, stringency * minHits,
-                                                                                                       stringency * minRatio, stringency * readlen),
-                        "#number of training examples: %d" % count]
-    outline = string.join(learningSettings, "\n")
+    outline = "#learn: stringency=%.2f min_signal=%2.f min_ratio=%.2f min_region_size=%d\n#number of training examples: %d" % (stringency,
+                                                                                                                               stringency * regionFinder.minHits,
+                                                                                                                               stringency * regionFinder.minRatio,
+                                                                                                                               stringency * regionFinder.readlen,
+                                                                                                                               count)
+
     print outline
     print outline
-    writeLog(logfilename, versionString, "%s%s" % (outfilename, outline))
-    if count < minSites:
+    writeLog(logfilename, versionString, outfilename + outline)
+    regionFinder.shiftValue = getShiftValue(shiftDict, count, logfilename, outfilename)
+    outline = "#picked shiftValue to be %d" % regionFinder.shiftValue
+    print outline
+    print >> outfile, outline
+    writeLog(logfilename, versionString, outfilename + outline)
+
+
+def previousRegionIsDone(pos, previousHit, maxSpacing, maxCoord):
+    return abs(pos - previousHit) > maxSpacing or pos == maxCoord
+
+
+def regionPassesCriteria(regionFinder, sumAll, numStarts, regionLength, stringency=1):
+    return sumAll >= stringency * regionFinder.minHits and numStarts > stringency * regionFinder.minRatio and regionLength > stringency * regionFinder.readlen
+
+
+def trimRegion(region, regionFinder, peak, regionStop, trimValue, currentReadList, totalReadCount):
+    bestPos = peak.topPos[0]
+    peakScore = peak.smoothArray[bestPos]
+    if regionFinder.normalize:
+        peakScore /= totalReadCount
+
+    minSignalThresh = trimValue * peakScore
+    start = findStartEdgePosition(peak, minSignalThresh)
+    regionEndPoint = regionStop - region.start - 1
+    stop = findStopEdgePosition(peak, regionEndPoint, minSignalThresh)
+
+    regionStop = region.start + stop
+    region.start += start
+
+    trimmedPeak = findPeak(currentReadList, region.start, regionStop - region.start, regionFinder.readlen, doWeight=True,
+                           leftPlus=regionFinder.doDirectionality, shift=peak.shift)
+
+    peak.numPlus = trimmedPeak.numPlus
+    peak.numLeftPlus = trimmedPeak.numLeftPlus
+    peak.topPos = trimmedPeak.topPos
+    peak.smoothArray = trimmedPeak.smoothArray
+
+    region.numReads = trimmedPeak.numHits
+    if regionFinder.normalize:
+        region.numReads /= totalReadCount
+
+    region.stop = regionStop + regionFinder.readlen - 1
+                          
+    return regionStop
+
+
+def findStartEdgePosition(peak, minSignalThresh):
+    start = 0
+    while not peakEdgeLocated(peak, start, minSignalThresh):
+        start += 1
+
+    return start
+
+
+def findStopEdgePosition(peak, stop, minSignalThresh):
+    while not peakEdgeLocated(peak, stop, minSignalThresh):
+        stop -= 1
+
+    return stop
+
+
+def peakEdgeLocated(peak, position, minSignalThresh):
+    return peak.smoothArray[position] >= minSignalThresh or position == peak.topPos[0]
+
+
+def getFoldRatio(regionFinder, controlRDS, sumAll, chromosome, regionStart, regionStop, useMulti, doControl):
+    """ Fold ratio calculated is total read weight over control
+    """
+    #TODO: this needs to be generalized as there is a point at which we want to use the sampleRDS instead of controlRDS
+    if doControl:
+        numMock = 1. + controlRDS.getCounts(chromosome, regionStart, regionStop, uniqs=True, multi=useMulti, splices=False, reportCombined=True)
+        if regionFinder.normalize:
+            numMock /= regionFinder.controlRDSsize
+
+        foldRatio = sumAll / numMock
+    else:
+        foldRatio = regionFinder.minRatio
+
+    return foldRatio
+
+
+def updateShiftDict(shiftDict, readList, regionStart, regionLength, readlen):
+    peak = findPeak(readList, regionStart, regionLength, readlen, doWeight=True, shift="auto")
+    try:
+        shiftDict[peak.shift] += 1
+    except KeyError:
+        shiftDict[peak.shift] = 1
+
+
+def getShiftValue(shiftDict, count, logfilename, outfilename):
+    if count < 30:
         outline = "#too few training examples to pick a shiftValue - defaulting to 0\n#consider picking a lower minimum or threshold"
         outline = "#too few training examples to pick a shiftValue - defaulting to 0\n#consider picking a lower minimum or threshold"
-        print >> outfile, outline
-        writeLog(logfilename, versionString, "%s%s" % (outfilename, outline))
+        print outline
+        writeLog(logfilename, versionString, outfilename + outline)
         shiftValue = 0
     else:
         shiftValue = 0
     else:
-        for shift in sorted(shiftDict):
-            if shiftDict[shift] > bestCount:
-                bestShift = shift
-                bestCount = shiftDict[shift]
-
-        shiftValue = bestShift
+        shiftValue = getBestShiftInDict(shiftDict)
         print shiftDict
 
         print shiftDict
 
-    outline = "#picked shiftValue to be %d" % shiftValue
-    print outline
-    print >> outfile, outline
-    writeLog(logfilename, versionString, "%s%s" % (outfilename, outline))
-
     return shiftValue
 
 
     return shiftValue
 
 
-def getFoldRatio(rds, chrom, start, stop, doControl, useMulti, normalize, sampleSize, sumAll, minRatio):
-    if doControl:
-        foldRatio = getFoldRatioFromRDS(rds, chrom, start, stop, useMulti, normalize, sampleSize, sumAll)
+def getRegion(regionStart, regionStop, factor, index, chromosome, sumAll, foldRatio, multiP,
+              peakDescription, shift, doDirectionality, leftPlusRatio, numLeft,
+              numPlus, plusRatio):
+
+    if doDirectionality:
+        if leftPlusRatio < numLeft / numPlus:
+            plusP = plusRatio * 100.
+            leftP = 100. * numLeft / numPlus
+            # we have a region that passes all criteria
+            region = Region.DirectionalRegion(regionStart, regionStop,
+                                              factor, index, chromosome, sumAll,
+                                              foldRatio, multiP, plusP, leftP,
+                                              peakDescription, shift)
+
+        else:
+            raise RegionDirectionError
     else:
     else:
-        foldRatio = minRatio
+        # we have a region, but didn't check for directionality
+        region = Region.Region(regionStart, regionStop, factor, index, chromosome,
+                               sumAll, foldRatio, multiP, peakDescription, shift)
 
 
-    return foldRatio
+    return region
 
 
 
 
-def getFoldRatioFromRDS(rds, chrom, start, stop, useMulti, normalize, sampleSize, sumAll):
-    numMock = 1. + rds.getCounts(chrom, start, stop, uniqs=True, multi=useMulti, splices=False, reportCombined=True)
+def setMultireadPercentage(region, hitRDS, hitRDSsize, currentTotalWeight, currentUniqueCount, chromosome, lastReadPos, normalize, doTrim):
+    if doTrim:
+        sumMulti = hitRDS.getMultiCount(chromosome, region.start, lastReadPos)
+    else:
+        sumMulti = currentTotalWeight - currentUniqueCount
+
+    # normalize to RPM
     if normalize:
     if normalize:
-        numMock /= sampleSize
+        sumMulti /= hitRDSsize
 
 
-    foldRatio = sumAll / numMock
+    try:
+        multiP = 100. * (sumMulti / region.numReads)
+    except ZeroDivisionError:
+        return
 
 
-    return foldRatio
+    region.multiP = multiP
 
 
 
 
-def locateRegions(rds, rdsSampleSize, referenceRDS, referenceSampleSize, chrom, useMulti,
-                normalize, maxSpacing, doDirectionality, doTrim, minHits, minRatio, readlen,
-                shiftValue, minPeak, minPlusRatio, maxPlusRatio, leftPlusRatio, listPeak,
-                noMulti, doControl, factor, trimValue, outputRegionList=False):
+def regionAndPeakPass(regionFinder, region, regionLength, peakScore, plusRatio):
+    regionPasses = False
+    if regionPassesCriteria(regionFinder, region.numReads, region.foldRatio, regionLength):
+        if peakScore >= regionFinder.minPeak and regionFinder.minPlusRatio <= plusRatio <= regionFinder.maxPlusRatio:
+            regionPasses = True
 
 
-    index = 0
-    totalRegionWeight = 0
-    failedCounter = 0
-    previousHit = - 1 * maxSpacing
-    currentHitList = [-1]
-    currentTotalWeight = 0
-    currentUniqReadCount = 0
-    currentReadList = []
-    regionWeights = []
-    outregions = []
-    numStarts = 0
-    hitDict = rds.getReadsDict(fullChrom=True, chrom=chrom, withWeight=True, doMulti=useMulti, findallOptimize=True)
-    maxCoord = rds.getMaxCoordinate(chrom, doMulti=useMulti)
-    for read in hitDict[chrom]:
-        pos = read["start"]
-        sense = read["sense"]
-        weight = read["weight"]
-        if abs(pos - previousHit) > maxSpacing or pos == maxCoord:
-            sumAll = currentTotalWeight
-            if normalize:
-                sumAll /= rdsSampleSize
-
-            regionStart = currentHitList[0]
-            regionStop = currentHitList[-1]
-            regionWeights.append(int(sumAll))
-            if sumAll >= minHits and numStarts > minRatio and (regionStop - regionStart) > readlen:
-                sumMulti = 0.
-                #first pass uses getFoldRatio on mockRDS as there may not be control
-                foldRatio = getFoldRatioFromRDS(referenceRDS, chrom, regionStart, regionStop, useMulti, normalize, referenceSampleSize, sumAll)
-                if foldRatio >= minRatio:
-                    # first pass, with absolute numbers
-                    peak = findPeak(currentReadList, regionStart, regionStop - regionStart, readlen, doWeight=True, leftPlus=doDirectionality, shift=shiftValue)
-
-                    bestPos = peak.topPos[0]
-                    numHits = peak.numHits
-                    peakScore = peak.smoothArray[bestPos]
-                    numPlus = peak.numPlus
-                    shift = peak.shift
-                    numLeft = peak.numLeft
-                    if normalize:
-                        peakScore /= rdsSampleSize
-
-                    if doTrim:
-                        minSignalThresh = trimValue * peakScore
-                        start = 0
-                        stop = regionStop - regionStart - 1
-                        startFound = False
-                        while not startFound:
-                            if peak.smoothArray[start] >= minSignalThresh or start == bestPos:
-                                startFound = True
-                            else:
-                                start += 1
-
-                        stopFound = False
-                        while not stopFound:
-                            if peak.smoothArray[stop] >= minSignalThresh or stop == bestPos:
-                                stopFound = True
-                            else:
-                                stop -= 1
-
-                        regionStop = regionStart + stop
-                        regionStart += start
-                        trimPeak = findPeak(currentReadList, regionStart, regionStop - regionStart, readlen, doWeight=True, leftPlus=doDirectionality, shift=shift)
-
-                        sumAll = trimPeak.numHits
-                        numPlus = trimPeak.numPlus
-                        numLeft = trimPeak.numLeft
-                        if normalize:
-                            sumAll /= rdsSampleSize
-
-                        foldRatio = getFoldRatio(referenceRDS, chrom, regionStart, regionStop, doControl, useMulti, normalize, referenceSampleSize, sumAll, minRatio)
-                        if outputRegionList:
-                            sumMulti = rds.getCounts(chrom, regionStart, regionStop, uniqs=False, multi=useMulti, splices=False, reportCombined=True)
-                        # just in case it changed, use latest data
-                        try:
-                            bestPos = trimPeak.topPos[0]
-                            peakScore = trimPeak.smoothArray[bestPos]
-                        except:
-                            continue
+    return regionPasses
 
 
-                        # normalize to RPM
-                        if normalize:
-                            peakScore /= rdsSampleSize
 
 
-                    elif outputRegionList:
-                        sumMulti = currentTotalWeight - currentUniqReadCount
+def updateRegion(region, doDirectionality, leftPlusRatio, numLeft, numPlus, plusRatio):
 
 
-                    if outputRegionList:
-                        # normalize to RPM
-                        if normalize:
-                            sumMulti /= rdsSampleSize
+    if doDirectionality:
+        if leftPlusRatio < numLeft / numPlus:
+            region.plusP = plusRatio * 100.
+            region.leftP = 100. * numLeft / numPlus
+        else:
+            raise RegionDirectionError
 
 
-                        try:
-                            multiP = 100. * (sumMulti / sumAll)
-                        except:
-                            break
 
 
-                        if noMulti:
-                            multiP = 0.
+def writeNoRevBackgroundResults(regionFinder, outregions, outfile, doPvalue, shiftDict,
+                                allregions, header):
 
 
-                    # check that we still pass threshold
-                    if sumAll >= minHits and  foldRatio >= minRatio and (regionStop - regionStart) > readlen:
-                        plusRatio = float(numPlus)/numHits
-                        if peakScore >= minPeak and minPlusRatio <= plusRatio <= maxPlusRatio:
-                            if outputRegionList:
-                                peakDescription = ""
-                                if listPeak:
-                                    peakDescription = "\t%d\t%.1f" % (regionStart + bestPos, peakScore)
-
-                            if doDirectionality:
-                                if leftPlusRatio < numLeft / numPlus:
-                                    index += 1
-                                    if outputRegionList:
-                                        plusP = plusRatio * 100.
-                                        leftP = 100. * numLeft / numPlus
-                                        # we have a region that passes all criteria
-                                        region = Region.DirectionalRegion(regionStart, regionStop + readlen - 1,
-                                                                          factor, index, chrom, sumAll,
-                                                                          foldRatio, multiP, plusP, leftP,
-                                                                          peakDescription, shift)
-                                        outregions.append(region)
-
-                                    totalRegionWeight += sumAll
-                                else:
-                                    failedCounter += 1
-                            else:
-                                # we have a region, but didn't check for directionality
-                                index += 1
-                                totalRegionWeight += sumAll
-                                if outputRegionList:
-                                    region = Region.Region(regionStart, regionStop + readlen - 1, factor, index, chrom,
-                                                           sumAll, foldRatio, multiP, peakDescription, shift)
-                                    outregions.append(region)
+    writeChromosomeResults(regionFinder, outregions, outfile, doPvalue, shiftDict,
+                           allregions, header, backregions=[], pValueType="self")
 
 
-            currentHitList = []
-            currentTotalWeight = 0
-            currentUniqReadCount = 0
-            currentReadList = []
-            numStarts = 0
 
 
-        if pos not in currentHitList:
-            numStarts += 1
+def writeChromosomeResults(regionFinder, outregions, outfile, doPvalue, shiftDict,
+                           allregions, header, backregions=[], pValueType="none"):
 
 
-        currentHitList.append(pos)
-        currentTotalWeight += weight
-        if weight == 1.0:
-            currentUniqReadCount += 1
+    print regionFinder.statistics["mIndex"], regionFinder.statistics["mTotal"]
+    if doPvalue:
+        if pValueType == "self":
+            poissonmean = calculatePoissonMean(allregions)
+        else:
+            poissonmean = calculatePoissonMean(backregions)
 
 
-        currentReadList.append({"start": pos, "sense": sense, "weight": weight})
-        previousHit = pos
+    print header
+    writeRegions(outregions, outfile, doPvalue, poissonmean, shiftValue=regionFinder.shiftValue, reportshift=regionFinder.reportshift, shiftDict=shiftDict)
 
 
-    statistics = {"index": index,
-                  "total": totalRegionWeight,
-                  "failed": failedCounter
-    }
 
 
-    if outputRegionList:
-        return statistics, regionWeights, outregions
-    else:
-        return statistics, regionWeights
+def calculatePoissonMean(dataList):
+    dataList.sort()
+    listSize = float(len(dataList))
+    try:
+        poissonmean = sum(dataList) / listSize
+    except ZeroDivisionError:
+        poissonmean = 0
 
 
+    print "Poisson n=%d, p=%f" % (listSize, poissonmean)
 
 
-def writeRegionsToFile(outfile, outregions, doPvalue, pValue, poissonmean, reportshift, shiftValue):
-    bestShift = 0
-    shiftDict = {}
+    return poissonmean
+
+
+def writeRegions(outregions, outfile, doPvalue, poissonmean, shiftValue=0, reportshift=False, shiftDict={}):
     for region in outregions:
     for region in outregions:
-        # iterative poisson from http://stackoverflow.com/questions/280797?sort=newest
-        if reportshift:
-            outputList = [region.printRegionWithShift()]
-            if shiftValue == "auto":
-                try:
-                    shiftDict[region.shift] += 1
-                except KeyError:
-                    shiftDict[region.shift] = 1
-        else:
-            outputList = [region.printRegion()]
+        if shiftValue == "auto" and reportshift:
+            try:
+                shiftDict[region.shift] += 1
+            except KeyError:
+                shiftDict[region.shift] = 1
+
+        outline = getRegionString(region, reportshift)
 
         # iterative poisson from http://stackoverflow.com/questions/280797?sort=newest
         if doPvalue:
             sumAll = int(region.numReads)
 
         # iterative poisson from http://stackoverflow.com/questions/280797?sort=newest
         if doPvalue:
             sumAll = int(region.numReads)
-            for i in xrange(sumAll):
-                pValue *= poissonmean
-                pValue /= i+1
+            pValue = calculatePValue(sumAll, poissonmean)
+            outline += "\t%1.2g" % pValue
 
 
-            outputList.append("%1.2f" % pValue)
-
-        outline = string.join(outputList, "\t")
         print outline
         print >> outfile, outline
 
         print outline
         print >> outfile, outline
 
-    bestCount = 0
-    for shift in sorted(shiftDict):
-        if shiftDict[shift] > bestCount:
-            bestShift = shift
-            bestCount = shiftDict[shift]
 
 
-    return bestShift
+def calculatePValue(sum, poissonmean):
+    pValue = math.exp(-poissonmean)
+    for i in xrange(sum):
+        pValue *= poissonmean
+        pValue /= i+1
 
 
+    return pValue
 
 
-def getFooter(stats, doDirectionality, doRevBackground, shiftValue, reportshift, shiftModeValue):
-    footerList = ["#stats:\t%.1f RPM in %d regions" % (stats["total"], stats["index"])]
-    if doDirectionality:
-        footerList.append("#\t\t%d additional regions failed directionality filter" % stats["failed"])
+
+def getRegionString(region, reportShift):
+    if reportShift:
+        outline = region.printRegionWithShift()
+    else:
+        outline = region.printRegion()
+
+    return outline
+
+
+def getFooter(regionFinder, shiftDict, doRevBackground):
+    index = regionFinder.statistics["index"]
+    mIndex = regionFinder.statistics["mIndex"]
+    footerLines = ["#stats:\t%.1f RPM in %d regions" % (regionFinder.statistics["total"], index)]
+    if regionFinder.doDirectionality:
+        footerLines.append("#\t\t%d additional regions failed directionality filter" % regionFinder.statistics["failed"])
 
     if doRevBackground:
         try:
 
     if doRevBackground:
         try:
-            percent = min(100. * (float(stats["mIndex"])/stats["index"]), 100)
-        except (ValueError, ZeroDivisionError):
+            percent = min(100. * (float(mIndex)/index), 100.)
+        except ZeroDivisionError:
             percent = 0.
 
             percent = 0.
 
-        footerList.append("#%d regions (%.1f RPM) found in background (FDR = %.2f percent)" % (stats["mIndex"], stats["mTotal"], percent))
+        footerLines.append("#%d regions (%.1f RPM) found in background (FDR = %.2f percent)" % (mIndex, regionFinder.statistics["mTotal"], percent))
+
+    if regionFinder.shiftValue == "auto" and regionFinder.reportshift:
+        bestShift = getBestShiftInDict(shiftDict)
+        footerLines.append("#mode of shift values: %d" % bestShift)
+
+    if regionFinder.statistics["badRegionTrim"] > 0:
+        footerLines.append("#%d regions discarded due to trimming problems" % regionFinder.statistics["badRegionTrim"])
+
+    return string.join(footerLines, "\n")
 
 
-    if shiftValue == "auto" and reportshift:
-        footerList.append("#mode of shift values: %d" % shiftModeValue)
 
 
-    footer = string.join(footerList, "\n")
+def getBestShiftInDict(shiftDict):
+    return max(shiftDict.iteritems(), key=operator.itemgetter(1))[0]
 
 
-    return footer
 
 if __name__ == "__main__":
 
 if __name__ == "__main__":
-    main(sys.argv)
\ No newline at end of file
+    main(sys.argv)