rewrite of findall.py and MakeRdsFromBam to fix bugs resulting from poor initial...
author Sean Upchurch <sau@caltech.edu>
Thu, 21 Jul 2011 19:48:10 +0000 (12:48 -0700)
committer Sean Upchurch <sau@caltech.edu>
Thu, 21 Jul 2011 19:48:10 +0000 (12:48 -0700)
and incorrect initial specifications

MakeRdsFromBam.py
ReadDataset.py
combinerds.py
findall.py

index 969d4ccf556524200904e6f78f53ddc1cd1018b2..13c901332b2c54dcf6b4033a6633ec0e074da782 100644 (file)
@@ -21,7 +21,7 @@ from commoncode import writeLog, getConfigParser, getConfigBoolOption, getConfig
 import ReadDataset
 
 INSERT_SIZE = 100000
 import ReadDataset
 
 INSERT_SIZE = 100000
-verstring = "makeRdsFromBam: version 1.0"
+verstring = "makeRdsFromBam: version 1.1"
 
 
 def main(argv=None):
 
 
 def main(argv=None):
@@ -54,8 +54,13 @@ def main(argv=None):
         print "no outrdsfile specified - see --help for usage"
         sys.exit(1)
 
         print "no outrdsfile specified - see --help for usage"
         sys.exit(1)
 
+    if options.rnaDataType:
+        dataType = "RNA"
+    else:
+        dataType = "DNA"
+
     makeRdsFromBam(label, samFileName, outDbName, options.init, options.doIndex, options.useSamFile,
     makeRdsFromBam(label, samFileName, outDbName, options.init, options.doIndex, options.useSamFile,
-                   options.cachePages, options.maxMultiReadCount, options.rnaDataType, options.trimReadID)
+                   options.cachePages, options.maxMultiReadCount, dataType, options.trimReadID)
 
 
 def getParser(usage):
 
 
 def getParser(usage):
@@ -92,26 +97,14 @@ def getParser(usage):
 
 
 def makeRdsFromBam(label, samFileName, outDbName, init=True, doIndex=False, useSamFile=False,
 
 
 def makeRdsFromBam(label, samFileName, outDbName, init=True, doIndex=False, useSamFile=False,
-                   cachePages=100000, maxMultiReadCount=10, rnaDataType=False, trimReadID=True):
+                   cachePages=100000, maxMultiReadCount=10, dataType="DNA", trimReadID=True):
 
     if useSamFile:
         fileMode = "r"
     else:
         fileMode = "rb"
 
 
     if useSamFile:
         fileMode = "r"
     else:
         fileMode = "rb"
 
-    try:
-        samfile = pysam.Samfile(samFileName, fileMode)
-    except ValueError:
-        print "samfile index not found"
-        sys.exit(1)
-
-    if rnaDataType:
-        dataType = "RNA"
-    else:
-        dataType = "DNA"
-
     writeLog("%s.log" % outDbName, verstring, string.join(sys.argv[1:]))
     writeLog("%s.log" % outDbName, verstring, string.join(sys.argv[1:]))
-
     rds = ReadDataset.ReadDataset(outDbName, init, dataType, verbose=True)
     if not init and doIndex:
         try:
     rds = ReadDataset.ReadDataset(outDbName, init, dataType, verbose=True)
     if not init and doIndex:
         try:
@@ -145,7 +138,8 @@ def makeRdsFromBam(label, samFileName, outDbName, init=True, doIndex=False, useS
                        "unique": 0,
                        "multi": 0,
                        "multiDiscard": 0,
                        "unique": 0,
                        "multi": 0,
                        "multiDiscard": 0,
-                       "splice": 0
+                       "splice": 0,
+                       "multisplice": 0
     }
 
     readsize = 0
     }
 
     readsize = 0
@@ -153,14 +147,25 @@ def makeRdsFromBam(label, samFileName, outDbName, init=True, doIndex=False, useS
     uniqueInsertList = []
     multiInsertList = []
     spliceInsertList = []
     uniqueInsertList = []
     multiInsertList = []
     spliceInsertList = []
+    multispliceInsertList = []
 
 
-    processedEntryDict = {}
     uniqueReadDict = {}
     multiReadDict = {}
     uniqueReadDict = {}
     multiReadDict = {}
+    multispliceReadDict = {}
     spliceReadDict = {}
     spliceReadDict = {}
+    multireadCounts = getMultiReadIDCounts(samFileName, fileMode)
 
 
-    samFileIterator = samfile.fetch(until_eof=True)
+    for readID in multireadCounts:
+        if multireadCounts[readID] > maxMultiReadCount:
+            totalReadCounts["multiDiscard"] += 1
+
+    try:
+        samfile = pysam.Samfile(samFileName, fileMode)
+    except ValueError:
+        print "samfile index not found"
+        sys.exit(1)
 
 
+    samFileIterator = samfile.fetch(until_eof=True)
     for read in samFileIterator:
         if read.is_unmapped:
             totalReadCounts["unmapped"] += 1
     for read in samFileIterator:
         if read.is_unmapped:
             totalReadCounts["unmapped"] += 1
@@ -172,39 +177,28 @@ def makeRdsFromBam(label, samFileName, outDbName, init=True, doIndex=False, useS
             if init:
                 rds.insertMetadata([("readsize", readsize)])
 
             if init:
                 rds.insertMetadata([("readsize", readsize)])
 
-        #Build the read dictionaries
-        try:
-            readSequence = read.seq
-        except KeyError:
-            readSequence = ""
-
         pairReadSuffix = getPairedReadNumberSuffix(read)
         pairReadSuffix = getPairedReadNumberSuffix(read)
-        readName = "%s%s%s" % (read.qname, readSequence, pairReadSuffix)
+        readName = "%s%s" % (read.qname, pairReadSuffix)
         if trimReadID:
             rdsEntryName = "%s:%s:%d%s" % (label, read.qname, totalReadCounts["total"], pairReadSuffix)
         else:
             rdsEntryName = read.qname
 
         if trimReadID:
             rdsEntryName = "%s:%s:%d%s" % (label, read.qname, totalReadCounts["total"], pairReadSuffix)
         else:
             rdsEntryName = read.qname
 
-        if processedEntryDict.has_key(readName):
-            if isSpliceEntry(read.cigar):
-                if spliceReadDict.has_key(readName):
-                    del spliceReadDict[readName]
-            else:
-                if uniqueReadDict.has_key(readName):
-                    del uniqueReadDict[readName]
-
-                if multiReadDict.has_key(readName):
-                    (read, priorCount, rdsEntryName) = multiReadDict[readName]
-                    count = priorCount + 1
-                    multiReadDict[readName] = (read, count, rdsEntryName)
-                else:
-                    multiReadDict[readName] = (read, 1, rdsEntryName)
-        else:
-            processedEntryDict[readName] = ""
+        try:
+            count = multireadCounts[readName]
+        except KeyError:
+            count = 1
+
+        if count == 1:
             if isSpliceEntry(read.cigar):
                 spliceReadDict[readName] = (read,rdsEntryName)
             else:
                 uniqueReadDict[readName] = (read, rdsEntryName)
             if isSpliceEntry(read.cigar):
                 spliceReadDict[readName] = (read,rdsEntryName)
             else:
                 uniqueReadDict[readName] = (read, rdsEntryName)
+        elif count <= maxMultiReadCount:
+            if isSpliceEntry(read.cigar):
+                multispliceReadDict[readName] = (read, count, rdsEntryName)
+            else:
+                multiReadDict[readName] = (read, count, rdsEntryName)
 
         if totalReadCounts["total"] % INSERT_SIZE == 0:
             for entry in uniqueReadDict.keys():
 
         if totalReadCounts["total"] % INSERT_SIZE == 0:
             for entry in uniqueReadDict.keys():
@@ -213,20 +207,23 @@ def makeRdsFromBam(label, samFileName, outDbName, init=True, doIndex=False, useS
                 uniqueInsertList.append(getRDSEntry(readData, rdsEntryName, chrom, readsize))
                 totalReadCounts["unique"] += 1
 
                 uniqueInsertList.append(getRDSEntry(readData, rdsEntryName, chrom, readsize))
                 totalReadCounts["unique"] += 1
 
-            for entry in spliceReadDict.keys():
-                (readData, rdsEntryName) = spliceReadDict[entry]
-                chrom = samfile.getrname(readData.rname)
-                spliceInsertList.append(getRDSSpliceEntry(readData, rdsEntryName, chrom, readsize))
-                totalReadCounts["splice"] += 1
-
             for entry in multiReadDict.keys():
                 (readData, count, rdsEntryName) = multiReadDict[entry]
                 chrom = samfile.getrname(readData.rname)
             for entry in multiReadDict.keys():
                 (readData, count, rdsEntryName) = multiReadDict[entry]
                 chrom = samfile.getrname(readData.rname)
-                if count > maxMultiReadCount:
-                    totalReadCounts["multiDiscard"] += 1
-                else:
-                    multiInsertList.append(getRDSEntry(readData, rdsEntryName, chrom, readsize, weight=count)) 
-                    totalReadCounts["multi"] += 1
+                multiInsertList.append(getRDSEntry(readData, rdsEntryName, chrom, readsize, weight=count)) 
+
+            if dataType == "RNA":
+                for entry in spliceReadDict.keys():
+                    (readData, rdsEntryName) = spliceReadDict[entry]
+                    chrom = samfile.getrname(readData.rname)
+                    spliceInsertList.append(getRDSSpliceEntry(readData, rdsEntryName, chrom, readsize))
+                    totalReadCounts["splice"] += 1
+
+                for entry in multispliceReadDict.keys():
+                    (readData, count, rdsEntryName) = multispliceReadDict[entry]
+                    chrom = samfile.getrname(readData.rname)
+                    multispliceInsertList.append(getRDSSpliceEntry(readData, rdsEntryName, chrom, readsize, weight=count))
+                    totalReadCounts["multisplice"] += 1
 
             rds.insertUniqs(uniqueInsertList)
             rds.insertMulti(multiInsertList)
 
             rds.insertUniqs(uniqueInsertList)
             rds.insertMulti(multiInsertList)
@@ -238,10 +235,12 @@ def makeRdsFromBam(label, samFileName, outDbName, init=True, doIndex=False, useS
                 rds.insertSplices(spliceInsertList)
                 spliceInsertList = []
                 spliceReadDict = {}
                 rds.insertSplices(spliceInsertList)
                 spliceInsertList = []
                 spliceReadDict = {}
+                rds.insertMultisplices(multispliceInsertList)
+                multispliceInsertList = []
+                multispliceReadDict = {}
 
             print ".",
             sys.stdout.flush()
 
             print ".",
             sys.stdout.flush()
-            processedEntryDict = {}
 
         totalReadCounts["total"] += 1
 
 
         totalReadCounts["total"] += 1
 
@@ -258,13 +257,9 @@ def makeRdsFromBam(label, samFileName, outDbName, init=True, doIndex=False, useS
         for entry in multiReadDict.keys():
             (readData, count, rdsEntryName) = multiReadDict[entry]
             chrom = samfile.getrname(readData.rname)
         for entry in multiReadDict.keys():
             (readData, count, rdsEntryName) = multiReadDict[entry]
             chrom = samfile.getrname(readData.rname)
-            if count > maxMultiReadCount:
-                totalReadCounts["multiDiscard"] += 1
-            else:
-                multiInsertList.append(getRDSEntry(readData, rdsEntryName, chrom, readsize, weight=count))
-                totalReadCounts["multi"] += 1
+            multiInsertList.append(getRDSEntry(readData, rdsEntryName, chrom, readsize, weight=count))
 
 
-        totalReadCounts["multi"] += len(multiInsertList)
+        rds.insertMulti(multiInsertList)
 
     if len(spliceReadDict.keys()) > 0 and dataType == "RNA":
         for entry in spliceReadDict.keys():
 
     if len(spliceReadDict.keys()) > 0 and dataType == "RNA":
         for entry in spliceReadDict.keys():
@@ -275,12 +270,23 @@ def makeRdsFromBam(label, samFileName, outDbName, init=True, doIndex=False, useS
 
         rds.insertSplices(spliceInsertList)
 
 
         rds.insertSplices(spliceInsertList)
 
+    if len(multispliceReadDict.keys()) > 0 and dataType == "RNA":
+        for entry in multispliceReadDict.keys():
+            (readData, count, rdsEntryName) = multispliceReadDict[entry]
+            chrom = samfile.getrname(readData.rname)
+            multispliceInsertList.append(getRDSSpliceEntry(readData, rdsEntryName, chrom, readsize, weight=count))
+            totalReadCounts["multisplice"] += 1
+
+        rds.insertMultisplices(multispliceInsertList)
+
+    totalReadCounts["multi"] = len(multireadCounts) - totalReadCounts["multiDiscard"] - totalReadCounts["multisplice"]
     countStringList = ["\n%d unmapped reads discarded" % totalReadCounts["unmapped"]]
     countStringList.append("%d unique reads" % totalReadCounts["unique"])
     countStringList.append("%d multi reads" % totalReadCounts["multi"])
     countStringList.append("%d multi reads count > %d discarded" % (totalReadCounts["multiDiscard"], maxMultiReadCount))
     if dataType == "RNA":
         countStringList.append("%d spliced reads" % totalReadCounts["splice"])
     countStringList = ["\n%d unmapped reads discarded" % totalReadCounts["unmapped"]]
     countStringList.append("%d unique reads" % totalReadCounts["unique"])
     countStringList.append("%d multi reads" % totalReadCounts["multi"])
     countStringList.append("%d multi reads count > %d discarded" % (totalReadCounts["multiDiscard"], maxMultiReadCount))
     if dataType == "RNA":
         countStringList.append("%d spliced reads" % totalReadCounts["splice"])
+        countStringList.append("%d spliced multireads" % totalReadCounts["multisplice"])
 
     print string.join(countStringList, "\n")
     outputCountText = string.join(countStringList, "\t")
 
     print string.join(countStringList, "\n")
     outputCountText = string.join(countStringList, "\t")
@@ -295,6 +301,29 @@ def makeRdsFromBam(label, samFileName, outDbName, init=True, doIndex=False, useS
             rds.buildIndex(defaultCacheSize)
 
 
             rds.buildIndex(defaultCacheSize)
 
 
+def getMultiReadIDCounts(samFileName, fileMode):
+    try:
+        samfile = pysam.Samfile(samFileName, fileMode)
+    except ValueError:
+        print "samfile index not found"
+        sys.exit(1)
+
+    readIDCounts = {}
+    for read in samfile.fetch(until_eof=True):
+        pairReadSuffix = getPairedReadNumberSuffix(read)
+        readName = "%s%s" % (read.qname, pairReadSuffix)
+        try:
+            readIDCounts[readName] += 1
+        except KeyError:
+            readIDCounts[readName] = 1
+
+    for readID in readIDCounts.keys():
+        if readIDCounts[readID] == 1:
+            del readIDCounts[readID]
+
+    return readIDCounts
+
+
 def getRDSEntry(alignedRead, readName, chrom, readSize, weight=1):
     start = int(alignedRead.pos)
     stop = int(start + readSize)
 def getRDSEntry(alignedRead, readName, chrom, readSize, weight=1):
     start = int(alignedRead.pos)
     stop = int(start + readSize)
@@ -308,11 +337,11 @@ def getRDSEntry(alignedRead, readName, chrom, readSize, weight=1):
     return (readName, chrom, start, stop, sense, 1.0/weight, '', mismatches)
 
 
     return (readName, chrom, start, stop, sense, 1.0/weight, '', mismatches)
 
 
-def getRDSSpliceEntry(alignedRead, readName, chrom, readSize):
-    (readName, chrom, start, stop, sense, weight, flag, mismatches) = getRDSEntry(alignedRead, readName, chrom, readSize)
+def getRDSSpliceEntry(alignedRead, readName, chrom, readSize, weight=1):
+    (readName, chrom, start, stop, sense, weight, flag, mismatches) = getRDSEntry(alignedRead, readName, chrom, readSize, weight)
     startL, startR, stopL, stopR = getSpliceBounds(start, readSize, alignedRead.cigar)
     
     startL, startR, stopL, stopR = getSpliceBounds(start, readSize, alignedRead.cigar)
     
-    return (readName, chrom, startL, stopL, startR, stopR, sense, 1.0, "", mismatches)
+    return (readName, chrom, startL, stopL, startR, stopR, sense, weight, "", mismatches)
 
 
 def getPairedReadNumberSuffix(read):
 
 
 def getPairedReadNumberSuffix(read):
index c9d2a0bf2b4eb56e0fa906be2319fcac866c30b3..71544ca2420d2489fca13452976d930a6c55a9c5 100644 (file)
@@ -6,7 +6,7 @@ import os
 from array import array
 from commoncode import getReverseComplement, getConfigParser, getConfigOption
 
 from array import array
 from commoncode import getReverseComplement, getConfigParser, getConfigOption
 
-currentRDSVersion = "2.0"
+currentRDSVersion = "2.1"
 
 
 class ReadDatasetError(Exception):
 
 
 class ReadDatasetError(Exception):
@@ -35,6 +35,9 @@ class ReadDataset():
         self.memCursor = ""
         self.cachedDBFile = ""
 
         self.memCursor = ""
         self.cachedDBFile = ""
 
+        if initialize and datasetType not in ["DNA", "RNA"]:
+            raise ReadDatasetError("failed to initialize: datasetType must be 'DNA' or 'RNA'")
+
         if cache:
             if verbose:
                 print "caching ...."
         if cache:
             if verbose:
                 print "caching ...."
@@ -48,11 +51,7 @@ class ReadDataset():
         self.dbcon.row_factory = sqlite.Row
         self.dbcon.execute("PRAGMA temp_store = MEMORY")
         if initialize:
         self.dbcon.row_factory = sqlite.Row
         self.dbcon.execute("PRAGMA temp_store = MEMORY")
         if initialize:
-            if datasetType not in ["DNA", "RNA"]:
-                raise ReadDatasetError("failed to initialize: datasetType must be 'DNA' or 'RNA'")
-            else:
-                self.dataType = datasetType
-
+            self.dataType = datasetType
             self.initializeTables(self.dbcon)
         else:
             metadata = self.getMetadata("dataType")
             self.initializeTables(self.dbcon)
         else:
             metadata = self.getMetadata("dataType")
@@ -69,38 +68,7 @@ class ReadDataset():
                 self.rdsVersion = "pre-1.0"
 
         if verbose:
                 self.rdsVersion = "pre-1.0"
 
         if verbose:
-            if initialize:
-                print "INITIALIZED dataset %s" % datafile
-            else:
-                print "dataset %s" % datafile
-
-            metadata = self.getMetadata()
-            print "metadata:"
-            pnameList = metadata.keys()
-            pnameList.sort()
-            for pname in pnameList:
-                print "\t" + pname + "\t" + metadata[pname]
-
-            if reportCount:
-                ucount = self.getUniqsCount()
-                mcount = self.getMultiCount()
-                if self.dataType == "DNA" and not initialize:
-                    try:
-                        print "\n%d unique reads and %d multireads" % (int(ucount), int(mcount))
-                    except ValueError:
-                        print "\n%s unique reads and %s multireads" % (ucount, mcount)
-                elif self.dataType == "RNA" and not initialize:
-                    scount = self.getSplicesCount()
-                    try:
-                        print "\n%d unique reads, %d spliced reads and %d multireads" % (int(ucount), int(scount), int(mcount))
-                    except ValueError:
-                        print "\n%s unique reads, %s spliced reads and %s multireads" % (ucount, scount, mcount)
-
-            print "default cache size is %d pages" % self.getDefaultCacheSize()
-            if self.hasIndex():
-                print "found index"
-            else:
-                print "not indexed"
+            self.printRDSInfo(datafile, reportCount, initialize)
 
 
     def __len__(self):
 
 
     def __len__(self):
@@ -124,6 +92,39 @@ class ReadDataset():
             self.uncacheDB()
 
 
             self.uncacheDB()
 
 
+    def printRDSInfo(self, datafile, reportCount, initialize):
+        if initialize:
+            print "INITIALIZED dataset %s" % datafile
+        else:
+            print "dataset %s" % datafile
+
+        metadata = self.getMetadata()
+        print "metadata:"
+        pnameList = metadata.keys()
+        pnameList.sort()
+        for pname in pnameList:
+            print "\t" + pname + "\t" + metadata[pname]
+
+        if reportCount and not initialize:
+            self.printReadCounts()
+
+        print "default cache size is %d pages" % self.getDefaultCacheSize()
+        if self.hasIndex():
+            print "found index"
+        else:
+            print "not indexed"
+
+
+    def printReadCounts(self):
+        ucount = self.getUniqsCount()
+        mcount = self.getMultiCount()
+        if self.dataType == "DNA":
+            print "\n%d unique reads and %d multireads" % (ucount, mcount)
+        elif self.dataType == "RNA":
+            scount = self.getSplicesCount()
+            print "\n%d unique reads, %d spliced reads and %d multireads" % (ucount, scount, mcount)
+
+
     def cacheDB(self, filename):
         """ copy geneinfoDB to a local cache.
         """
     def cacheDB(self, filename):
         """ copy geneinfoDB to a local cache.
         """
@@ -246,6 +247,10 @@ class ReadDataset():
             tableSchema = "(ID INTEGER PRIMARY KEY, readID varchar, chrom varchar, %s, sense varchar, weight real, flag varchar, mismatch varchar)" % positionSchema
             dbConnection.execute("create table splices %s" % tableSchema)
 
             tableSchema = "(ID INTEGER PRIMARY KEY, readID varchar, chrom varchar, %s, sense varchar, weight real, flag varchar, mismatch varchar)" % positionSchema
             dbConnection.execute("create table splices %s" % tableSchema)
 
+            positionSchema = "startL int, stopL int, startR int, stopR int"
+            tableSchema = "(ID INTEGER PRIMARY KEY, readID varchar, chrom varchar, %s, sense varchar, weight real, flag varchar, mismatch varchar)" % positionSchema
+            dbConnection.execute("create table multisplices %s" % tableSchema)
+
         dbConnection.commit()
 
 
         dbConnection.commit()
 
 
@@ -993,6 +998,14 @@ class ReadDataset():
         self.dbcon.commit()
 
 
         self.dbcon.commit()
 
 
+    def insertMultisplices(self, valuesList):
+        """ inserts a list of (readID, chrom, startL, stopL, startR, stopR, sense, weight, flag, mismatch)
+        into the multisplices table.
+        """
+        self.dbcon.executemany("insert into multisplices(ID, readID, chrom, startL, stopL, startR, stopR, sense, weight, flag, mismatch) values (NULL,?,?,?,?,?,?,?,?,?,?)", valuesList)
+        self.dbcon.commit()
+
+
     def flagReads(self, regionsList, uniqs=True, multi=False, splices=False, sense="both"):
         """ update reads on file database in a list region of regions for a chromosome to have a new flag.
             regionsList must have 4 fields per region of the form (flag, chrom, start, stop) or, with
     def flagReads(self, regionsList, uniqs=True, multi=False, splices=False, sense="both"):
         """ update reads on file database in a list region of regions for a chromosome to have a new flag.
             regionsList must have 4 fields per region of the form (flag, chrom, start, stop) or, with
index 2878423a9716a0e4f31d0a1738a0770339c8ed6c..2695e6b512c677735fc61be2931108e947d0ab43 100755 (executable)
@@ -35,9 +35,7 @@ def main(argv=None):
     combinerds(datafile, infileList, options.tableList, options.withFlag, options.doIndex, options.cachePages, options.doInit, options.initRNA)
 
 
     combinerds(datafile, infileList, options.tableList, options.withFlag, options.doIndex, options.cachePages, options.doInit, options.initRNA)
 
 
-def makeParser():
-    usage = __doc__
-
+def makeParser(usage):
     parser = optparse.OptionParser(usage=usage)
     parser.add_option("--table", action="append", dest="tablelist")
     parser.add_option("--init", action="store_true", dest="doInit")
     parser = optparse.OptionParser(usage=usage)
     parser.add_option("--table", action="append", dest="tablelist")
     parser.add_option("--init", action="store_true", dest="doInit")
index 663d7e075d5d30535da6eed6547d4da2b4970be2..f2c70432b1ccc477ca3ccee55b708a0e3e82f4bf 100755 (executable)
@@ -60,11 +60,12 @@ print versionString
 
 class RegionDirectionError(Exception):
     pass
 
 class RegionDirectionError(Exception):
     pass
-
+            
 
 class RegionFinder():
 
 class RegionFinder():
-    def __init__(self, minRatio, minPeak, minPlusRatio, maxPlusRatio, leftPlusRatio, strandfilter, minHits, trimValue, doTrim, doDirectionality,
-                 shiftValue, maxSpacing):
+    def __init__(self, label, minRatio=4.0, minPeak=0.5, minPlusRatio=0.25, maxPlusRatio=0.75, leftPlusRatio=0.3, strandfilter="",
+                 minHits=4.0, trimValue=0.1, doTrim=True, doDirectionality=True, shiftValue=0, maxSpacing=50, withFlag="",
+                 normalize=True, listPeak=False, reportshift=False, stringency=1.0):
 
         self.statistics = {"index": 0,
                            "total": 0,
 
         self.statistics = {"index": 0,
                            "total": 0,
@@ -74,6 +75,8 @@ class RegionFinder():
                            "badRegionTrim": 0
         }
 
                            "badRegionTrim": 0
         }
 
+        self.regionLabel = label
+        self.rnaSettings = False
         self.controlRDSsize = 1
         self.sampleRDSsize = 1
         self.minRatio = minRatio
         self.controlRDSsize = 1
         self.sampleRDSsize = 1
         self.minRatio = minRatio
@@ -107,17 +110,89 @@ class RegionFinder():
 
         self.shiftValue = shiftValue
         self.maxSpacing = maxSpacing
 
         self.shiftValue = shiftValue
         self.maxSpacing = maxSpacing
+        self.withFlag = withFlag
+        self.normalize = normalize
+        self.listPeak = listPeak
+        self.reportshift = reportshift
+        self.stringency = max(stringency, 1.0)
 
 
 
 
-    def useRNASettings(self):
+    def useRNASettings(self, readlen):
+        self.rnaSettings = True
         self.shiftValue = 0
         self.doTrim = False
         self.doDirectionality = False
         self.shiftValue = 0
         self.doTrim = False
         self.doDirectionality = False
+        self.maxSpacing = readlen
+
+
+    def getHeader(self, doPvalue):
+        if self.normalize:
+            countType = "RPM"
+        else:
+            countType = "COUNT"
+
+        headerFields = ["#regionID\tchrom\tstart\tstop", countType, "fold\tmulti%"]
+
+        if self.doDirectionality:
+            headerFields.append("plus%\tleftPlus%")
+
+        if self.listPeak:
+            headerFields.append("peakPos\tpeakHeight")
+
+        if self.reportshift:
+            headerFields.append("readShift")
+
+        if doPvalue:
+            headerFields.append("pValue")
+
+        return string.join(headerFields, "\t")
+
+
+    def printSettings(self, doRevBackground, ptype, doControl, useMulti, doCache, pValueType):
+        print
+        self.printStatusMessages(doRevBackground, ptype, doControl, useMulti)
+        self.printOptionsSummary(useMulti, doCache, pValueType)
+
+
+    def printStatusMessages(self, doRevBackground, ptype, doControl, useMulti):
+        if self.shiftValue == "learn":
+            print "Will try to learn shift"
+
+        if self.normalize:
+            print "Normalizing to RPM"
+
+        if doRevBackground:
+            print "Swapping IP and background to calculate FDR"
+
+        if ptype != "":
+            if ptype in ["NONE", "SELF"]:
+                pass
+            elif ptype == "BACK":
+                if doControl and doRevBackground:
+                    pass
+                else:
+                    print "must have a control dataset and -revbackground for pValue type 'back'"
+            else:
+                print "could not use pValue type : %s" % ptype
+
+        if self.withFlag != "":
+            print "restrict to flag = %s" % self.withFlag
+
+        if not useMulti:
+            print "using unique reads only"
+
+        if self.rnaSettings:
+            print "using settings appropriate for RNA: -nodirectionality -notrim -noshift"
+
+        if self.strandfilter == "plus":
+            print "only analyzing reads on the plus strand"
+        elif self.strandfilter == "minus":
+            print "only analyzing reads on the minus strand"
 
 
 
 
-    def printOptionsSummary(self, listPeak, useMulti, doCache, pValueType):
+    def printOptionsSummary(self, useMulti, doCache, pValueType):
 
 
-        print "\nenforceDirectionality=%s listPeak=%s nomulti=%s cache=%s " % (self.doDirectionality, listPeak, not useMulti, doCache)
+        print "\nenforceDirectionality=%s listPeak=%s nomulti=%s cache=%s " % (self.doDirectionality, self.listPeak, not useMulti, doCache)
         print "spacing<%d minimum>%.1f ratio>%.1f minPeak=%.1f\ttrimmed=%s\tstrand=%s" % (self.maxSpacing, self.minHits, self.minRatio, self.minPeak, self.trimString, self.stranded)
         try:
             print "minPlus=%.2f maxPlus=%.2f leftPlus=%.2f shift=%d pvalue=%s" % (self.minPlusRatio, self.maxPlusRatio, self.leftPlusRatio, self.shiftValue, pValueType)
         print "spacing<%d minimum>%.1f ratio>%.1f minPeak=%.1f\ttrimmed=%s\tstrand=%s" % (self.maxSpacing, self.minHits, self.minRatio, self.minPeak, self.trimString, self.stranded)
         try:
             print "minPlus=%.2f maxPlus=%.2f leftPlus=%.2f shift=%d pvalue=%s" % (self.minPlusRatio, self.maxPlusRatio, self.leftPlusRatio, self.shiftValue, pValueType)
@@ -125,8 +200,7 @@ class RegionFinder():
             print "minPlus=%.2f maxPlus=%.2f leftPlus=%.2f shift=%s pvalue=%s" % (self.minPlusRatio, self.maxPlusRatio, self.leftPlusRatio, self.shiftValue, pValueType)
 
 
             print "minPlus=%.2f maxPlus=%.2f leftPlus=%.2f shift=%s pvalue=%s" % (self.minPlusRatio, self.maxPlusRatio, self.leftPlusRatio, self.shiftValue, pValueType)
 
 
-    def getAnalysisDescription(self, hitfile, controlfile, doControl,
-                                 withFlag, listPeak, useMulti, doCache, pValueType):
+    def getAnalysisDescription(self, hitfile, useMulti, doCache, pValueType, controlfile, doControl):
 
         description = ["#ERANGE %s" % versionString]
         if doControl:
 
         description = ["#ERANGE %s" % versionString]
         if doControl:
@@ -134,10 +208,10 @@ class RegionFinder():
         else:
             description.append("#enriched sample:\t%s (%.1f M reads)\n#control sample: none" % (hitfile, self.sampleRDSsize))
 
         else:
             description.append("#enriched sample:\t%s (%.1f M reads)\n#control sample: none" % (hitfile, self.sampleRDSsize))
 
-        if withFlag != "":
-            description.append("#restrict to Flag = %s" % withFlag)
+        if self.withFlag != "":
+            description.append("#restrict to Flag = %s" % self.withFlag)
 
 
-        description.append("#enforceDirectionality=%s listPeak=%s nomulti=%s cache=%s" % (self.doDirectionality, listPeak, not useMulti, doCache))
+        description.append("#enforceDirectionality=%s listPeak=%s nomulti=%s cache=%s" % (self.doDirectionality, self.listPeak, not useMulti, doCache))
         description.append("#spacing<%d minimum>%.1f ratio>%.1f minPeak=%.1f trimmed=%s strand=%s" % (self.maxSpacing, self.minHits, self.minRatio, self.minPeak, self.trimString, self.stranded))
         try:
             description.append("#minPlus=%.2f maxPlus=%.2f leftPlus=%.2f shift=%d pvalue=%s" % (self.minPlusRatio, self.maxPlusRatio, self.leftPlusRatio, self.shiftValue, pValueType))
         description.append("#spacing<%d minimum>%.1f ratio>%.1f minPeak=%.1f trimmed=%s strand=%s" % (self.maxSpacing, self.minHits, self.minRatio, self.minPeak, self.trimString, self.stranded))
         try:
             description.append("#minPlus=%.2f maxPlus=%.2f leftPlus=%.2f shift=%d pvalue=%s" % (self.minPlusRatio, self.maxPlusRatio, self.leftPlusRatio, self.shiftValue, pValueType))
@@ -147,6 +221,22 @@ class RegionFinder():
         return string.join(description, "\n")
 
 
         return string.join(description, "\n")
 
 
+    def updateControlStatistics(self, peak, sumAll, peakScore):
+
+        plusRatio = float(peak.numPlus)/peak.numHits
+        if peakScore >= self.minPeak and self.minPlusRatio <= plusRatio <= self.maxPlusRatio:
+            if self.doDirectionality:
+                if self.leftPlusRatio < peak.numLeft / peak.numPlus:
+                    self.statistics["mIndex"] += 1
+                    self.statistics["mTotal"] += sumAll
+                else:
+                    self.statistics["failed"] += 1
+            else:
+                # we have a region, but didn't check for directionality
+                self.statistics["mIndex"] += 1
+                self.statistics["mTotal"] += sumAll
+
+
 def usage():
     print __doc__
 
 def usage():
     print __doc__
 
@@ -186,13 +276,16 @@ def main(argv=None):
     else:
         outputMode = "w"
 
     else:
         outputMode = "w"
 
-    findall(factor, hitfile, outfilename, options.minHits, options.minRatio, options.maxSpacing, options.listPeak, shiftValue,
-            options.stringency, options.reportshift,
-            options.minPlusRatio, options.maxPlusRatio, options.leftPlusRatio, options.minPeak,
-            options.normalize, options.logfilename, options.withFlag, options.doDirectionality,
-            options.trimValue, options.doTrim, outputMode, options.rnaSettings,
-            options.cachePages, options.ptype, options.controlfile, options.doRevBackground, options.useMulti,
-            options.strandfilter, options.combine5p)
+    regionFinder = RegionFinder(factor, minRatio=options.minRatio, minPeak=options.minPeak, minPlusRatio=options.minPlusRatio,
+                                maxPlusRatio=options.maxPlusRatio, leftPlusRatio=options.leftPlusRatio, strandfilter=options.strandfilter,
+                                minHits=options.minHits, trimValue=options.trimValue, doTrim=options.doTrim,
+                                doDirectionality=options.doDirectionality, shiftValue=shiftValue, maxSpacing=options.maxSpacing,
+                                withFlag=options.withFlag, normalize=options.normalize, listPeak=options.listPeak,
+                                reportshift=options.reportshift, stringency=options.stringency)
+
+    findall(regionFinder, hitfile, outfilename, options.logfilename, outputMode, options.rnaSettings,
+            options.cachePages, options.ptype, options.controlfile, options.doRevBackground,
+            options.useMulti, options.combine5p)
 
 
 def makeParser():
 
 
 def makeParser():
@@ -270,72 +363,44 @@ def makeParser():
     return parser
 
 
     return parser
 
 
-def findall(factor, hitfile, outfilename, minHits=4.0, minRatio=4.0, maxSpacing=50, listPeak=False, shiftValue=0,
-            stringency=4.0, reportshift=False, minPlusRatio=0.25, maxPlusRatio=0.75, leftPlusRatio=0.3, minPeak=0.5,
-            normalize=True, logfilename="findall.log", withFlag="", doDirectionality=True, trimValue=0.1, doTrim=True,
-            outputMode="w", rnaSettings=False, cachePages=None, ptype="", controlfile=None, doRevBackground=False,
-            useMulti=True, strandfilter="", combine5p=False):
-
-    regionFinder = RegionFinder(minRatio, minPeak, minPlusRatio, maxPlusRatio, leftPlusRatio, strandfilter, minHits, trimValue,
-                                doTrim, doDirectionality, shiftValue, maxSpacing)
-
-    doControl = controlfile is not None
-    pValueType = getPValueType(ptype, doControl, doRevBackground)
-    doPvalue = not pValueType == "none"
-
-    if rnaSettings:
-        regionFinder.useRNASettings()
+def findall(regionFinder, hitfile, outfilename, logfilename="findall.log", outputMode="w", rnaSettings=False, cachePages=None,
+            ptype="", controlfile=None, doRevBackground=False, useMulti=True, combine5p=False):
 
 
+    writeLog(logfilename, versionString, string.join(sys.argv[1:]))
     doCache = cachePages is not None
     doCache = cachePages is not None
-    printStatusMessages(regionFinder.shiftValue, normalize, doRevBackground, ptype, doControl, withFlag, useMulti, rnaSettings, regionFinder.strandfilter)
     controlRDS = None
     controlRDS = None
+    doControl = controlfile is not None
     if doControl:
         print "\ncontrol:" 
         controlRDS = openRDSFile(controlfile, cachePages=cachePages, doCache=doCache)
     if doControl:
         print "\ncontrol:" 
         controlRDS = openRDSFile(controlfile, cachePages=cachePages, doCache=doCache)
+        regionFinder.controlRDSsize = len(controlRDS) / 1000000.
 
     print "\nsample:" 
     hitRDS = openRDSFile(hitfile, cachePages=cachePages, doCache=doCache)
 
     print "\nsample:" 
     hitRDS = openRDSFile(hitfile, cachePages=cachePages, doCache=doCache)
-
-    print
+    regionFinder.sampleRDSsize = len(hitRDS) / 1000000.
+    pValueType = getPValueType(ptype, doControl, doRevBackground)
+    doPvalue = not pValueType == "none"
     regionFinder.readlen = hitRDS.getReadSize()
     if rnaSettings:
     regionFinder.readlen = hitRDS.getReadSize()
     if rnaSettings:
-        regionFinder.maxSpacing = regionFinder.readlen
-
-    writeLog(logfilename, versionString, string.join(sys.argv[1:]))
-    regionFinder.printOptionsSummary(listPeak, useMulti, doCache, pValueType)
-
-    regionFinder.sampleRDSsize = len(hitRDS) / 1000000.
-    if doControl:
-        regionFinder.controlRDSsize = len(controlRDS) / 1000000.
+        regionFinder.useRNASettings(regionFinder.readlen)
 
 
+    regionFinder.printSettings(doRevBackground, ptype, doControl, useMulti, doCache, pValueType)
     outfile = open(outfilename, outputMode)
     outfile = open(outfilename, outputMode)
-    print >> outfile, regionFinder.getAnalysisDescription(hitfile, controlfile, doControl,
-                                                          withFlag, listPeak, useMulti, doCache, pValueType)
-
-    header = getHeader(normalize, regionFinder.doDirectionality, listPeak, reportshift, doPvalue)
-    print >> outfile, header
-    if minRatio < minPeak:
-        minPeak = minRatio
-
+    header = writeOutputFileHeader(regionFinder, outfile, hitfile, useMulti, doCache, pValueType, doPvalue, controlfile, doControl)
     shiftDict = {}
     shiftDict = {}
-    hitChromList = hitRDS.getChromosomes()
-    stringency = max(stringency, 1.0)
-    chromosomeList = getChromosomeListToProcess(hitChromList, controlRDS, doControl)
+    chromosomeList = getChromosomeListToProcess(hitRDS, controlRDS, doControl)
     for chromosome in chromosomeList:
     for chromosome in chromosomeList:
-        allregions, outregions = findPeakRegions(regionFinder, hitRDS, controlRDS, chromosome, logfilename, outfilename,
-                                                 outfile, stringency, normalize, useMulti, doControl, withFlag, combine5p,
-                                                 factor, listPeak)
+        if regionFinder.shiftValue == "learn":
+            learnShift(regionFinder, hitRDS, chromosome, logfilename, outfilename, outfile, useMulti, doControl, controlRDS, combine5p)
 
 
+        allregions, outregions = findPeakRegions(regionFinder, hitRDS, chromosome, logfilename, outfilename, outfile, useMulti, doControl, controlRDS, combine5p)
         if doRevBackground:
         if doRevBackground:
-            #TODO: this is *almost* the same calculation - there are small yet important differences
-            backregions = findBackgroundRegions(regionFinder, hitRDS, controlRDS, chromosome, normalize, useMulti, factor)
-            writeChromosomeResults(regionFinder, outregions, outfile, doPvalue, reportshift, shiftDict,
-                                   allregions, header, backregions=backregions, pValueType=pValueType)
+            backregions = findBackgroundRegions(regionFinder, hitRDS, controlRDS, chromosome, useMulti)
+            writeChromosomeResults(regionFinder, outregions, outfile, doPvalue, shiftDict, allregions, header, backregions=backregions, pValueType=pValueType)
         else:
         else:
-            writeNoRevBackgroundResults(regionFinder, outregions, outfile, doPvalue, reportshift, shiftDict,
-                                        allregions, header)
+            writeNoRevBackgroundResults(regionFinder, outregions, outfile, doPvalue, shiftDict, allregions, header)
 
 
-    footer = getFooter(regionFinder, shiftDict, reportshift, doRevBackground)
+    footer = getFooter(regionFinder, shiftDict, doRevBackground)
     print footer
     print >> outfile, footer
     outfile.close()
     print footer
     print >> outfile, footer
     outfile.close()
@@ -358,42 +423,6 @@ def getPValueType(ptype, doControl, doRevBackground):
     return pValueType
 
 
     return pValueType
 
 
-def printStatusMessages(shiftValue, normalize, doRevBackground, ptype, doControl, withFlag, useMulti, rnaSettings, strandfilter):
-    if shiftValue == "learn":
-        print "Will try to learn shift"
-
-    if normalize:
-        print "Normalizing to RPM"
-
-    if doRevBackground:
-        print "Swapping IP and background to calculate FDR"
-
-    if ptype != "":
-        if ptype in ["NONE", "SELF"]:
-            pass
-        elif ptype == "BACK":
-            if doControl and doRevBackground:
-                pass
-            else:
-                print "must have a control dataset and -revbackground for pValue type 'back'"
-        else:
-            print "could not use pValue type : %s" % ptype
-
-    if withFlag != "":
-        print "restrict to flag = %s" % withFlag
-
-    if not useMulti:
-        print "using unique reads only"
-
-    if rnaSettings:
-        print "using settings appropriate for RNA: -nodirectionality -notrim -noshift"
-
-    if strandfilter == "plus":
-        print "only analyzing reads on the plus strand"
-    elif strandfilter == "minus":
-        print "only analyzing reads on the minus strand"
-
-
 def openRDSFile(filename, cachePages=None, doCache=False):
     rds = ReadDataset.ReadDataset(filename, verbose=True, cache=doCache)
     if cachePages > rds.getDefaultCacheSize():
 def openRDSFile(filename, cachePages=None, doCache=False):
     rds = ReadDataset.ReadDataset(filename, verbose=True, cache=doCache)
     if cachePages > rds.getDefaultCacheSize():
@@ -402,30 +431,16 @@ def openRDSFile(filename, cachePages=None, doCache=False):
     return rds
 
 
     return rds
 
 
-def getHeader(normalize, doDirectionality, listPeak, reportshift, doPvalue):
-    if normalize:
-        countType = "RPM"
-    else:
-        countType = "COUNT"
-
-    headerFields = ["#regionID\tchrom\tstart\tstop", countType, "fold\tmulti%"]
-
-    if doDirectionality:
-        headerFields.append("plus%\tleftPlus%")
-
-    if listPeak:
-        headerFields.append("peakPos\tpeakHeight")
-
-    if reportshift:
-        headerFields.append("readShift")
-
-    if doPvalue:
-        headerFields.append("pValue")
+def writeOutputFileHeader(regionFinder, outfile, hitfile, useMulti, doCache, pValueType, doPvalue, controlfile, doControl):
+    print >> outfile, regionFinder.getAnalysisDescription(hitfile, useMulti, doCache, pValueType, controlfile, doControl)
+    header = regionFinder.getHeader(doPvalue)
+    print >> outfile, header
 
 
-    return string.join(headerFields, "\t")
+    return header
 
 
 
 
-def getChromosomeListToProcess(hitChromList, controlRDS, doControl):
+def getChromosomeListToProcess(hitRDS, controlRDS=None, doControl=False):
+    hitChromList = hitRDS.getChromosomes()
     if doControl:
         controlChromList = controlRDS.getChromosomes()
         chromosomeList = [chrom for chrom in hitChromList if chrom in controlChromList and chrom != "chrM"]
     if doControl:
         controlChromList = controlRDS.getChromosomes()
         chromosomeList = [chrom for chrom in hitChromList if chrom in controlChromList and chrom != "chrM"]
@@ -435,8 +450,8 @@ def getChromosomeListToProcess(hitChromList, controlRDS, doControl):
     return chromosomeList
 
 
     return chromosomeList
 
 
-def findPeakRegions(regionFinder, hitRDS, controlRDS, chromosome, logfilename, outfilename,
-                    outfile, stringency, normalize, useMulti, doControl, withFlag, combine5p, factor, listPeak):
+def findPeakRegions(regionFinder, hitRDS, chromosome, logfilename, outfilename,
+                    outfile, useMulti, doControl, controlRDS, combine5p):
 
     outregions = []
     allregions = []
 
     outregions = []
     allregions = []
@@ -448,14 +463,10 @@ def findPeakRegions(regionFinder, hitRDS, controlRDS, chromosome, logfilename, o
     reads = []
     numStarts = 0
     badRegion = False
     reads = []
     numStarts = 0
     badRegion = False
-    hitDict = hitRDS.getReadsDict(fullChrom=True, chrom=chromosome, flag=withFlag, withWeight=True, doMulti=useMulti, findallOptimize=True,
+    hitDict = hitRDS.getReadsDict(fullChrom=True, chrom=chromosome, flag=regionFinder.withFlag, withWeight=True, doMulti=useMulti, findallOptimize=True,
                                   strand=regionFinder.stranded, combine5p=combine5p)
 
     maxCoord = hitRDS.getMaxCoordinate(chromosome, doMulti=useMulti)
                                   strand=regionFinder.stranded, combine5p=combine5p)
 
     maxCoord = hitRDS.getMaxCoordinate(chromosome, doMulti=useMulti)
-    if regionFinder.shiftValue == "learn":
-        learnShift(regionFinder, hitDict, controlRDS, chromosome, maxCoord, logfilename, outfilename,
-                                outfile, numStarts, stringency, normalize, useMulti, doControl)
-
     for read in hitDict[chromosome]:
         pos = read["start"]
         if previousRegionIsDone(pos, previousHit, regionFinder.maxSpacing, maxCoord):
     for read in hitDict[chromosome]:
         pos = read["start"]
         if previousRegionIsDone(pos, previousHit, regionFinder.maxSpacing, maxCoord):
@@ -463,54 +474,52 @@ def findPeakRegions(regionFinder, hitRDS, controlRDS, chromosome, logfilename, o
             lastBasePosition = lastReadPos + regionFinder.readlen - 1
             newRegionIndex = regionFinder.statistics["index"] + 1
             if regionFinder.doDirectionality:
             lastBasePosition = lastReadPos + regionFinder.readlen - 1
             newRegionIndex = regionFinder.statistics["index"] + 1
             if regionFinder.doDirectionality:
-                region = Region.DirectionalRegion(readStartPositions[0], lastBasePosition, chrom=chromosome, index=newRegionIndex, label=factor, numReads=totalWeight)
+                region = Region.DirectionalRegion(readStartPositions[0], lastBasePosition, chrom=chromosome, index=newRegionIndex, label=regionFinder.regionLabel,
+                                                  numReads=totalWeight)
             else:
             else:
-                region = Region.Region(readStartPositions[0], lastBasePosition, chrom=chromosome, index=newRegionIndex, label=factor, numReads=totalWeight)
+                region = Region.Region(readStartPositions[0], lastBasePosition, chrom=chromosome, index=newRegionIndex, label=regionFinder.regionLabel, numReads=totalWeight)
 
 
-            if normalize:
+            if regionFinder.normalize:
                 region.numReads /= regionFinder.sampleRDSsize
 
             allregions.append(int(region.numReads))
             regionLength = lastReadPos - region.start
                 region.numReads /= regionFinder.sampleRDSsize
 
             allregions.append(int(region.numReads))
             regionLength = lastReadPos - region.start
-            if regionPassesCriteria(region.numReads, regionFinder.minHits, numStarts, regionFinder.minRatio, regionLength, regionFinder.readlen):
-                region.foldRatio = getFoldRatio(regionFinder, controlRDS, region.numReads, chromosome, region.start, lastReadPos,
-                                                useMulti, doControl, normalize)
+            if regionPassesCriteria(regionFinder, region.numReads, numStarts, regionLength):
+                region.foldRatio = getFoldRatio(regionFinder, controlRDS, region.numReads, chromosome, region.start, lastReadPos, useMulti, doControl)
 
                 if region.foldRatio >= regionFinder.minRatio:
                     # first pass, with absolute numbers
                     peak = findPeak(reads, region.start, regionLength, regionFinder.readlen, doWeight=True, leftPlus=regionFinder.doDirectionality, shift=regionFinder.shiftValue)
                     if regionFinder.doTrim:
                         try:
 
                 if region.foldRatio >= regionFinder.minRatio:
                     # first pass, with absolute numbers
                     peak = findPeak(reads, region.start, regionLength, regionFinder.readlen, doWeight=True, leftPlus=regionFinder.doDirectionality, shift=regionFinder.shiftValue)
                     if regionFinder.doTrim:
                         try:
-                            lastReadPos = trimRegion(region, peak, lastReadPos, regionFinder.trimValue, reads, regionFinder.readlen, regionFinder.doDirectionality,
-                                                     normalize, regionFinder.sampleRDSsize)
+                            lastReadPos = trimRegion(region, regionFinder, peak, lastReadPos, regionFinder.trimValue, reads, regionFinder.sampleRDSsize)
                         except IndexError:
                             badRegion = True
                             continue
 
                         except IndexError:
                             badRegion = True
                             continue
 
-                        region.foldRatio = getFoldRatio(regionFinder, controlRDS, region.numReads, chromosome, region.start, lastReadPos,
-                                                        useMulti, doControl, normalize)
+                        region.foldRatio = getFoldRatio(regionFinder, controlRDS, region.numReads, chromosome, region.start, lastReadPos, useMulti, doControl)
 
                     # just in case it changed, use latest data
                     try:
                         bestPos = peak.topPos[0]
                         peakScore = peak.smoothArray[bestPos]
 
                     # just in case it changed, use latest data
                     try:
                         bestPos = peak.topPos[0]
                         peakScore = peak.smoothArray[bestPos]
-                        if normalize:
+                        if regionFinder.normalize:
                             peakScore /= regionFinder.sampleRDSsize
                     except:
                         continue
 
                             peakScore /= regionFinder.sampleRDSsize
                     except:
                         continue
 
-                    if listPeak:
+                    if regionFinder.listPeak:
                         region.peakDescription= "%d\t%.1f" % (region.start + bestPos, peakScore)
 
                     if useMulti:
                         region.peakDescription= "%d\t%.1f" % (region.start + bestPos, peakScore)
 
                     if useMulti:
-                        setMultireadPercentage(region, hitRDS, regionFinder.sampleRDSsize, totalWeight, uniqueReadCount, chromosome, lastReadPos, normalize, regionFinder.doTrim)
+                        setMultireadPercentage(region, hitRDS, regionFinder.sampleRDSsize, totalWeight, uniqueReadCount, chromosome, lastReadPos,
+                                               regionFinder.normalize, regionFinder.doTrim)
 
                     region.shift = peak.shift
                     # check that we still pass threshold
                     regionLength = lastReadPos - region.start
                     plusRatio = float(peak.numPlus)/peak.numHits
 
                     region.shift = peak.shift
                     # check that we still pass threshold
                     regionLength = lastReadPos - region.start
                     plusRatio = float(peak.numPlus)/peak.numHits
-                    if regionAndPeakPass(region, regionFinder.minHits, regionFinder.minRatio, regionLength, regionFinder.readlen, peakScore, regionFinder.minPeak,
-                                         regionFinder.minPlusRatio, regionFinder.maxPlusRatio, plusRatio):
+                    if regionAndPeakPass(regionFinder, region, regionLength, peakScore, plusRatio):
                         try:
                             updateRegion(region, regionFinder.doDirectionality, regionFinder.leftPlusRatio, peak.numLeftPlus, peak.numPlus, plusRatio)
                             regionFinder.statistics["index"] += 1
                         try:
                             updateRegion(region, regionFinder.doDirectionality, regionFinder.leftPlusRatio, peak.numLeftPlus, peak.numPlus, plusRatio)
                             regionFinder.statistics["index"] += 1
@@ -543,8 +552,8 @@ def findPeakRegions(regionFinder, hitRDS, controlRDS, chromosome, logfilename, o
     return allregions, outregions
 
 
     return allregions, outregions
 
 
-def findBackgroundRegions(regionFinder, hitRDS, controlRDS, chromosome, normalize, useMulti, factor):
-
+def findBackgroundRegions(regionFinder, hitRDS, controlRDS, chromosome, useMulti):
+    #TODO: this is *almost* the same calculation - there are small yet important differences
     print "calculating background..."
     previousHit = - 1 * regionFinder.maxSpacing
     currentHitList = [-1]
     print "calculating background..."
     previousHit = - 1 * regionFinder.maxSpacing
     currentHitList = [-1]
@@ -560,16 +569,16 @@ def findBackgroundRegions(regionFinder, hitRDS, controlRDS, chromosome, normaliz
         if previousRegionIsDone(pos, previousHit, regionFinder.maxSpacing, maxCoord):
             lastReadPos = currentHitList[-1]
             lastBasePosition = lastReadPos + regionFinder.readlen - 1
         if previousRegionIsDone(pos, previousHit, regionFinder.maxSpacing, maxCoord):
             lastReadPos = currentHitList[-1]
             lastBasePosition = lastReadPos + regionFinder.readlen - 1
-            region = Region.Region(currentHitList[0], lastBasePosition, chrom=chromosome, label=factor, numReads=currentTotalWeight)
-            if normalize:
+            region = Region.Region(currentHitList[0], lastBasePosition, chrom=chromosome, label=regionFinder.regionLabel, numReads=currentTotalWeight)
+            if regionFinder.normalize:
                 region.numReads /= regionFinder.controlRDSsize
 
             backregions.append(int(region.numReads))
                 region.numReads /= regionFinder.controlRDSsize
 
             backregions.append(int(region.numReads))
-            region = Region.Region(currentHitList[0], lastBasePosition, chrom=chromosome, label=factor, numReads=currentTotalWeight)
+            region = Region.Region(currentHitList[0], lastBasePosition, chrom=chromosome, label=regionFinder.regionLabel, numReads=currentTotalWeight)
             regionLength = lastReadPos - region.start
             regionLength = lastReadPos - region.start
-            if regionPassesCriteria(region.numReads, regionFinder.minHits, numStarts, regionFinder.minRatio, regionLength, regionFinder.readlen):
+            if regionPassesCriteria(regionFinder, region.numReads, numStarts, regionLength):
                 numMock = 1. + hitRDS.getCounts(chromosome, region.start, lastReadPos, uniqs=True, multi=useMulti, splices=False, reportCombined=True)
                 numMock = 1. + hitRDS.getCounts(chromosome, region.start, lastReadPos, uniqs=True, multi=useMulti, splices=False, reportCombined=True)
-                if normalize:
+                if regionFinder.normalize:
                     numMock /= regionFinder.sampleRDSsize
 
                 foldRatio = region.numReads / numMock
                     numMock /= regionFinder.sampleRDSsize
 
                 foldRatio = region.numReads / numMock
@@ -580,14 +589,13 @@ def findBackgroundRegions(regionFinder, hitRDS, controlRDS, chromosome, normaliz
 
                     if regionFinder.doTrim:
                         try:
 
                     if regionFinder.doTrim:
                         try:
-                            lastReadPos = trimRegion(region, peak, lastReadPos, 20., currentReadList, regionFinder.readlen, regionFinder.doDirectionality,
-                                                     normalize, regionFinder.controlRDSsize)
+                            lastReadPos = trimRegion(region, regionFinder, peak, lastReadPos, 20., currentReadList, regionFinder.controlRDSsize)
                         except IndexError:
                             badRegion = True
                             continue
 
                         numMock = 1. + hitRDS.getCounts(chromosome, region.start, lastReadPos, uniqs=True, multi=useMulti, splices=False, reportCombined=True)
                         except IndexError:
                             badRegion = True
                             continue
 
                         numMock = 1. + hitRDS.getCounts(chromosome, region.start, lastReadPos, uniqs=True, multi=useMulti, splices=False, reportCombined=True)
-                        if normalize:
+                        if regionFinder.normalize:
                             numMock /= regionFinder.sampleRDSsize
 
                         foldRatio = region.numReads / numMock
                             numMock /= regionFinder.sampleRDSsize
 
                         foldRatio = region.numReads / numMock
@@ -600,13 +608,13 @@ def findBackgroundRegions(regionFinder, hitRDS, controlRDS, chromosome, normaliz
                         continue
 
                     # normalize to RPM
                         continue
 
                     # normalize to RPM
-                    if normalize:
+                    if regionFinder.normalize:
                         peakScore /= regionFinder.controlRDSsize
 
                     # check that we still pass threshold
                     regionLength = lastReadPos - region.start
                         peakScore /= regionFinder.controlRDSsize
 
                     # check that we still pass threshold
                     regionLength = lastReadPos - region.start
-                    if regionPassesCriteria(region.numReads, regionFinder.minHits, foldRatio, regionFinder.minRatio, regionLength, regionFinder.readlen):
-                        updateControlStatistics(regionFinder, peak, region.numReads, peakScore)
+                    if regionPassesCriteria(regionFinder, region.numReads, foldRatio, regionLength):
+                        regionFinder.updateControlStatistics(peak, region.numReads, peakScore)
 
             currentHitList = []
             currentTotalWeight = 0
 
             currentHitList = []
             currentTotalWeight = 0
@@ -628,27 +636,33 @@ def findBackgroundRegions(regionFinder, hitRDS, controlRDS, chromosome, normaliz
     return backregions
 
 
     return backregions
 
 
-def learnShift(regionFinder, hitDict, controlRDS, chromosome, maxCoord, logfilename, outfilename,
-               outfile, numStarts, stringency, normalize, useMulti, doControl):
+def learnShift(regionFinder, hitRDS, chromosome, logfilename, outfilename,
+               outfile, useMulti, doControl, controlRDS, combine5p):
+
+    hitDict = hitRDS.getReadsDict(fullChrom=True, chrom=chromosome, flag=regionFinder.withFlag, withWeight=True, doMulti=useMulti, findallOptimize=True,
+                                  strand=regionFinder.stranded, combine5p=combine5p)
 
 
+    maxCoord = hitRDS.getMaxCoordinate(chromosome, doMulti=useMulti)
     print "learning shift.... will need at least 30 training sites"
     print "learning shift.... will need at least 30 training sites"
+    stringency = regionFinder.stringency
     previousHit = -1 * regionFinder.maxSpacing
     positionList = [-1]
     totalWeight = 0
     readList = []
     shiftDict = {}
     count = 0
     previousHit = -1 * regionFinder.maxSpacing
     positionList = [-1]
     totalWeight = 0
     readList = []
     shiftDict = {}
     count = 0
+    numStarts = 0
     for read in hitDict[chromosome]:
         pos = read["start"]
         if previousRegionIsDone(pos, previousHit, regionFinder.maxSpacing, maxCoord):
     for read in hitDict[chromosome]:
         pos = read["start"]
         if previousRegionIsDone(pos, previousHit, regionFinder.maxSpacing, maxCoord):
-            if normalize:
+            if regionFinder.normalize:
                 totalWeight /= regionFinder.sampleRDSsize
 
             regionStart = positionList[0]
             regionStop = positionList[-1]
             regionLength = regionStop - regionStart
                 totalWeight /= regionFinder.sampleRDSsize
 
             regionStart = positionList[0]
             regionStop = positionList[-1]
             regionLength = regionStop - regionStart
-            if regionPassesCriteria(totalWeight, regionFinder.minHits, numStarts, regionFinder.minRatio, regionLength, regionFinder.readlen, stringency=stringency):
-                foldRatio = getFoldRatio(regionFinder, controlRDS, totalWeight, chromosome, regionStart, regionStop, useMulti, doControl, normalize)
+            if regionPassesCriteria(regionFinder, totalWeight, numStarts, regionLength, stringency=stringency):
+                foldRatio = getFoldRatio(regionFinder, controlRDS, totalWeight, chromosome, regionStart, regionStop, useMulti, doControl)
                 if foldRatio >= regionFinder.minRatio:
                     updateShiftDict(shiftDict, readList, regionStart, regionLength, regionFinder.readlen)
                     count += 1
                 if foldRatio >= regionFinder.minRatio:
                     updateShiftDict(shiftDict, readList, regionStart, regionLength, regionFinder.readlen)
                     count += 1
@@ -685,15 +699,15 @@ def previousRegionIsDone(pos, previousHit, maxSpacing, maxCoord):
     return abs(pos - previousHit) > maxSpacing or pos == maxCoord
 
 
     return abs(pos - previousHit) > maxSpacing or pos == maxCoord
 
 
-def regionPassesCriteria(sumAll, minHits, numStarts, minRatio, regionLength, readlen, stringency=1):
-    return sumAll >= stringency * minHits and numStarts > stringency * minRatio and regionLength > stringency * readlen
+def regionPassesCriteria(regionFinder, sumAll, numStarts, regionLength, stringency=1):
+    return sumAll >= stringency * regionFinder.minHits and numStarts > stringency * regionFinder.minRatio and regionLength > stringency * regionFinder.readlen
 
 
 
 
-def trimRegion(region, peak, regionStop, trimValue, currentReadList, readlen, doDirectionality, normalize, hitRDSsize):
+def trimRegion(region, regionFinder, peak, regionStop, trimValue, currentReadList, totalReadCount):
     bestPos = peak.topPos[0]
     peakScore = peak.smoothArray[bestPos]
     bestPos = peak.topPos[0]
     peakScore = peak.smoothArray[bestPos]
-    if normalize:
-        peakScore /= hitRDSsize
+    if regionFinder.normalize:
+        peakScore /= totalReadCount
 
     minSignalThresh = trimValue * peakScore
     start = findStartEdgePosition(peak, minSignalThresh)
 
     minSignalThresh = trimValue * peakScore
     start = findStartEdgePosition(peak, minSignalThresh)
@@ -702,18 +716,21 @@ def trimRegion(region, peak, regionStop, trimValue, currentReadList, readlen, do
 
     regionStop = region.start + stop
     region.start += start
-    trimmedPeak = findPeak(currentReadList, region.start, regionStop - region.start, readlen, doWeight=True, leftPlus=doDirectionality, shift=peak.shift)
+
+    trimmedPeak = findPeak(currentReadList, region.start, regionStop - region.start, regionFinder.readlen, doWeight=True,
+                           leftPlus=regionFinder.doDirectionality, shift=peak.shift)
+
     peak.numPlus = trimmedPeak.numPlus
     peak.numLeftPlus = trimmedPeak.numLeftPlus
     peak.topPos = trimmedPeak.topPos
     peak.smoothArray = trimmedPeak.smoothArray
 
     region.numReads = trimmedPeak.numHits
-    if normalize:
-        region.numReads /= hitRDSsize
+    if regionFinder.normalize:
+        region.numReads /= totalReadCount
 
-    region.stop = regionStop + readlen - 1
-                                
+    region.stop = regionStop + regionFinder.readlen - 1
+                          
     return regionStop
 
 
@@ -736,10 +753,13 @@ def peakEdgeLocated(peak, position, minSignalThresh):
     return peak.smoothArray[position] >= minSignalThresh or position == peak.topPos[0]
 
 
-def getFoldRatio(regionFinder, controlRDS, sumAll, chromosome, regionStart, regionStop, useMulti, doControl, normalize):
+def getFoldRatio(regionFinder, controlRDS, sumAll, chromosome, regionStart, regionStop, useMulti, doControl):
+    """ Fold ratio calculated is total read weight over control
+    """
+    #TODO: this needs to be generalized as there is a point at which we want to use the sampleRDS instead of controlRDS
     if doControl:
         numMock = 1. + controlRDS.getCounts(chromosome, regionStart, regionStop, uniqs=True, multi=useMulti, splices=False, reportCombined=True)
-        if normalize:
+        if regionFinder.normalize:
             numMock /= regionFinder.controlRDSsize
 
         foldRatio = sumAll / numMock
@@ -770,22 +790,6 @@ def getShiftValue(shiftDict, count, logfilename, outfilename):
     return shiftValue
 
 
-def updateControlStatistics(regionFinder, peak, sumAll, peakScore):
-
-    plusRatio = float(peak.numPlus)/peak.numHits
-    if peakScore >= regionFinder.minPeak and regionFinder.minPlusRatio <= plusRatio <= regionFinder.maxPlusRatio:
-        if regionFinder.doDirectionality:
-            if regionFinder.leftPlusRatio < peak.numLeft / peak.numPlus:
-                regionFinder.statistics["mIndex"] += 1
-                regionFinder.statistics["mTotal"] += sumAll
-            else:
-                regionFinder.statistics["failed"] += 1
-        else:
-            # we have a region, but didn't check for directionality
-            regionFinder.statistics["mIndex"] += 1
-            regionFinder.statistics["mTotal"] += sumAll
-
-
 def getRegion(regionStart, regionStop, factor, index, chromosome, sumAll, foldRatio, multiP,
               peakDescription, shift, doDirectionality, leftPlusRatio, numLeft,
               numPlus, plusRatio):
@@ -828,18 +832,16 @@ def setMultireadPercentage(region, hitRDS, hitRDSsize, currentTotalWeight, curre
     region.multiP = multiP
 
 
-def regionAndPeakPass(region, minHits, minRatio, regionLength, readlen, peakScore, minPeak, minPlusRatio, maxPlusRatio, plusRatio):
+def regionAndPeakPass(regionFinder, region, regionLength, peakScore, plusRatio):
     regionPasses = False
-    if regionPassesCriteria(region.numReads, minHits, region.foldRatio, minRatio, regionLength, readlen):
-        if peakScore >= minPeak and minPlusRatio <= plusRatio <= maxPlusRatio:
+    if regionPassesCriteria(regionFinder, region.numReads, region.foldRatio, regionLength):
+        if peakScore >= regionFinder.minPeak and regionFinder.minPlusRatio <= plusRatio <= regionFinder.maxPlusRatio:
             regionPasses = True
 
     return regionPasses
 
 
-def updateRegion(region,
-                 doDirectionality, leftPlusRatio, numLeft,
-                 numPlus, plusRatio):
+def updateRegion(region, doDirectionality, leftPlusRatio, numLeft, numPlus, plusRatio):
 
     if doDirectionality:
         if leftPlusRatio < numLeft / numPlus:
@@ -849,14 +851,14 @@ def updateRegion(region,
             raise RegionDirectionError
 
 
-def writeNoRevBackgroundResults(regionFinder, outregions, outfile, doPvalue, reportshift, shiftDict,
+def writeNoRevBackgroundResults(regionFinder, outregions, outfile, doPvalue, shiftDict,
                                 allregions, header):
 
-    writeChromosomeResults(regionFinder, outregions, outfile, doPvalue, reportshift, shiftDict,
+    writeChromosomeResults(regionFinder, outregions, outfile, doPvalue, shiftDict,
                            allregions, header, backregions=[], pValueType="self")
 
 
-def writeChromosomeResults(regionFinder, outregions, outfile, doPvalue, reportshift, shiftDict,
+def writeChromosomeResults(regionFinder, outregions, outfile, doPvalue, shiftDict,
                            allregions, header, backregions=[], pValueType="none"):
 
     print regionFinder.statistics["mIndex"], regionFinder.statistics["mTotal"]
@@ -867,7 +869,7 @@ def writeChromosomeResults(regionFinder, outregions, outfile, doPvalue, reportsh
             poissonmean = calculatePoissonMean(backregions)
 
     print header
-    writeRegions(outregions, outfile, doPvalue, poissonmean, shiftValue=regionFinder.shiftValue, reportshift=reportshift, shiftDict=shiftDict)
+    writeRegions(outregions, outfile, doPvalue, poissonmean, shiftValue=regionFinder.shiftValue, reportshift=regionFinder.reportshift, shiftDict=shiftDict)
 
 
 def calculatePoissonMean(dataList):
@@ -905,7 +907,6 @@ def writeRegions(outregions, outfile, doPvalue, poissonmean, shiftValue=0, repor
 
 def calculatePValue(sum, poissonmean):
     pValue = math.exp(-poissonmean)
-    #TODO: 798: DeprecationWarning: integer argument expected, got float - for i in xrange(sum)
     for i in xrange(sum):
         pValue *= poissonmean
         pValue /= i+1
@@ -922,7 +923,7 @@ def getRegionString(region, reportShift):
     return outline
 
 
-def getFooter(regionFinder, shiftDict, reportshift, doRevBackground):
+def getFooter(regionFinder, shiftDict, doRevBackground):
     index = regionFinder.statistics["index"]
     mIndex = regionFinder.statistics["mIndex"]
     footerLines = ["#stats:\t%.1f RPM in %d regions" % (regionFinder.statistics["total"], index)]
@@ -937,7 +938,7 @@ def getFooter(regionFinder, shiftDict, reportshift, doRevBackground):
 
         footerLines.append("#%d regions (%.1f RPM) found in background (FDR = %.2f percent)" % (mIndex, regionFinder.statistics["mTotal"], percent))
 
-    if regionFinder.shiftValue == "auto" and reportshift:
+    if regionFinder.shiftValue == "auto" and regionFinder.reportshift:
         bestShift = getBestShiftInDict(shiftDict)
         footerLines.append("#mode of shift values: %d" % bestShift)
 
         bestShift = getBestShiftInDict(shiftDict)
         footerLines.append("#mode of shift values: %d" % bestShift)