rewrite of findall.py and MakeRdsFromBam to fix bugs resulting from poor initial...
[erange.git] / MakeRdsFromBam.py
index 969d4ccf556524200904e6f78f53ddc1cd1018b2..13c901332b2c54dcf6b4033a6633ec0e074da782 100644 (file)
@@ -21,7 +21,7 @@ from commoncode import writeLog, getConfigParser, getConfigBoolOption, getConfig
 import ReadDataset
 
 INSERT_SIZE = 100000
-verstring = "makeRdsFromBam: version 1.0"
+verstring = "makeRdsFromBam: version 1.1"
 
 
 def main(argv=None):
@@ -54,8 +54,13 @@ def main(argv=None):
         print "no outrdsfile specified - see --help for usage"
         sys.exit(1)
 
+    if options.rnaDataType:
+        dataType = "RNA"
+    else:
+        dataType = "DNA"
+
     makeRdsFromBam(label, samFileName, outDbName, options.init, options.doIndex, options.useSamFile,
-                   options.cachePages, options.maxMultiReadCount, options.rnaDataType, options.trimReadID)
+                   options.cachePages, options.maxMultiReadCount, dataType, options.trimReadID)
 
 
 def getParser(usage):
@@ -92,26 +97,14 @@ def getParser(usage):
 
 
 def makeRdsFromBam(label, samFileName, outDbName, init=True, doIndex=False, useSamFile=False,
-                   cachePages=100000, maxMultiReadCount=10, rnaDataType=False, trimReadID=True):
+                   cachePages=100000, maxMultiReadCount=10, dataType="DNA", trimReadID=True):
 
     if useSamFile:
         fileMode = "r"
     else:
         fileMode = "rb"
 
-    try:
-        samfile = pysam.Samfile(samFileName, fileMode)
-    except ValueError:
-        print "samfile index not found"
-        sys.exit(1)
-
-    if rnaDataType:
-        dataType = "RNA"
-    else:
-        dataType = "DNA"
-
     writeLog("%s.log" % outDbName, verstring, string.join(sys.argv[1:]))
-
     rds = ReadDataset.ReadDataset(outDbName, init, dataType, verbose=True)
     if not init and doIndex:
         try:
@@ -145,7 +138,8 @@ def makeRdsFromBam(label, samFileName, outDbName, init=True, doIndex=False, useS
                        "unique": 0,
                        "multi": 0,
                        "multiDiscard": 0,
-                       "splice": 0
+                       "splice": 0,
+                       "multisplice": 0
     }
 
     readsize = 0
@@ -153,14 +147,25 @@ def makeRdsFromBam(label, samFileName, outDbName, init=True, doIndex=False, useS
     uniqueInsertList = []
     multiInsertList = []
     spliceInsertList = []
+    multispliceInsertList = []
 
-    processedEntryDict = {}
     uniqueReadDict = {}
     multiReadDict = {}
+    multispliceReadDict = {}
     spliceReadDict = {}
+    multireadCounts = getMultiReadIDCounts(samFileName, fileMode)
 
-    samFileIterator = samfile.fetch(until_eof=True)
+    for readID in multireadCounts:
+        if multireadCounts[readID] > maxMultiReadCount:
+            totalReadCounts["multiDiscard"] += 1
+
+    try:
+        samfile = pysam.Samfile(samFileName, fileMode)
+    except ValueError:
+        print "samfile index not found"
+        sys.exit(1)
 
+    samFileIterator = samfile.fetch(until_eof=True)
     for read in samFileIterator:
         if read.is_unmapped:
             totalReadCounts["unmapped"] += 1
@@ -172,39 +177,28 @@ def makeRdsFromBam(label, samFileName, outDbName, init=True, doIndex=False, useS
             if init:
                 rds.insertMetadata([("readsize", readsize)])
 
-        #Build the read dictionaries
-        try:
-            readSequence = read.seq
-        except KeyError:
-            readSequence = ""
-
         pairReadSuffix = getPairedReadNumberSuffix(read)
-        readName = "%s%s%s" % (read.qname, readSequence, pairReadSuffix)
+        readName = "%s%s" % (read.qname, pairReadSuffix)
         if trimReadID:
             rdsEntryName = "%s:%s:%d%s" % (label, read.qname, totalReadCounts["total"], pairReadSuffix)
         else:
             rdsEntryName = read.qname
 
-        if processedEntryDict.has_key(readName):
-            if isSpliceEntry(read.cigar):
-                if spliceReadDict.has_key(readName):
-                    del spliceReadDict[readName]
-            else:
-                if uniqueReadDict.has_key(readName):
-                    del uniqueReadDict[readName]
-
-                if multiReadDict.has_key(readName):
-                    (read, priorCount, rdsEntryName) = multiReadDict[readName]
-                    count = priorCount + 1
-                    multiReadDict[readName] = (read, count, rdsEntryName)
-                else:
-                    multiReadDict[readName] = (read, 1, rdsEntryName)
-        else:
-            processedEntryDict[readName] = ""
+        try:
+            count = multireadCounts[readName]
+        except KeyError:
+            count = 1
+
+        if count == 1:
             if isSpliceEntry(read.cigar):
                 spliceReadDict[readName] = (read,rdsEntryName)
             else:
                 uniqueReadDict[readName] = (read, rdsEntryName)
+        elif count <= maxMultiReadCount:
+            if isSpliceEntry(read.cigar):
+                multispliceReadDict[readName] = (read, count, rdsEntryName)
+            else:
+                multiReadDict[readName] = (read, count, rdsEntryName)
 
         if totalReadCounts["total"] % INSERT_SIZE == 0:
             for entry in uniqueReadDict.keys():
@@ -213,20 +207,23 @@ def makeRdsFromBam(label, samFileName, outDbName, init=True, doIndex=False, useS
                 uniqueInsertList.append(getRDSEntry(readData, rdsEntryName, chrom, readsize))
                 totalReadCounts["unique"] += 1
 
-            for entry in spliceReadDict.keys():
-                (readData, rdsEntryName) = spliceReadDict[entry]
-                chrom = samfile.getrname(readData.rname)
-                spliceInsertList.append(getRDSSpliceEntry(readData, rdsEntryName, chrom, readsize))
-                totalReadCounts["splice"] += 1
-
             for entry in multiReadDict.keys():
                 (readData, count, rdsEntryName) = multiReadDict[entry]
                 chrom = samfile.getrname(readData.rname)
-                if count > maxMultiReadCount:
-                    totalReadCounts["multiDiscard"] += 1
-                else:
-                    multiInsertList.append(getRDSEntry(readData, rdsEntryName, chrom, readsize, weight=count)) 
-                    totalReadCounts["multi"] += 1
+                multiInsertList.append(getRDSEntry(readData, rdsEntryName, chrom, readsize, weight=count)) 
+
+            if dataType == "RNA":
+                for entry in spliceReadDict.keys():
+                    (readData, rdsEntryName) = spliceReadDict[entry]
+                    chrom = samfile.getrname(readData.rname)
+                    spliceInsertList.append(getRDSSpliceEntry(readData, rdsEntryName, chrom, readsize))
+                    totalReadCounts["splice"] += 1
+
+                for entry in multispliceReadDict.keys():
+                    (readData, count, rdsEntryName) = multispliceReadDict[entry]
+                    chrom = samfile.getrname(readData.rname)
+                    multispliceInsertList.append(getRDSSpliceEntry(readData, rdsEntryName, chrom, readsize, weight=count))
+                    totalReadCounts["multisplice"] += 1
 
             rds.insertUniqs(uniqueInsertList)
             rds.insertMulti(multiInsertList)
@@ -238,10 +235,12 @@ def makeRdsFromBam(label, samFileName, outDbName, init=True, doIndex=False, useS
                 rds.insertSplices(spliceInsertList)
                 spliceInsertList = []
                 spliceReadDict = {}
+                rds.insertMultisplices(multispliceInsertList)
+                multispliceInsertList = []
+                multispliceReadDict = {}
 
             print ".",
             sys.stdout.flush()
-            processedEntryDict = {}
 
         totalReadCounts["total"] += 1
 
@@ -258,13 +257,9 @@ def makeRdsFromBam(label, samFileName, outDbName, init=True, doIndex=False, useS
         for entry in multiReadDict.keys():
             (readData, count, rdsEntryName) = multiReadDict[entry]
             chrom = samfile.getrname(readData.rname)
-            if count > maxMultiReadCount:
-                totalReadCounts["multiDiscard"] += 1
-            else:
-                multiInsertList.append(getRDSEntry(readData, rdsEntryName, chrom, readsize, weight=count))
-                totalReadCounts["multi"] += 1
+            multiInsertList.append(getRDSEntry(readData, rdsEntryName, chrom, readsize, weight=count))
 
-        totalReadCounts["multi"] += len(multiInsertList)
+        rds.insertMulti(multiInsertList)
 
     if len(spliceReadDict.keys()) > 0 and dataType == "RNA":
         for entry in spliceReadDict.keys():
@@ -275,12 +270,23 @@ def makeRdsFromBam(label, samFileName, outDbName, init=True, doIndex=False, useS
 
         rds.insertSplices(spliceInsertList)
 
+    if len(multispliceReadDict.keys()) > 0 and dataType == "RNA":
+        for entry in multispliceReadDict.keys():
+            (readData, count, rdsEntryName) = multispliceReadDict[entry]
+            chrom = samfile.getrname(readData.rname)
+            multispliceInsertList.append(getRDSSpliceEntry(readData, rdsEntryName, chrom, readsize, weight=count))
+            totalReadCounts["multisplice"] += 1
+
+        rds.insertMultisplices(multispliceInsertList)
+
+    totalReadCounts["multi"] = len(multireadCounts) - totalReadCounts["multiDiscard"] - totalReadCounts["multisplice"]
     countStringList = ["\n%d unmapped reads discarded" % totalReadCounts["unmapped"]]
     countStringList.append("%d unique reads" % totalReadCounts["unique"])
     countStringList.append("%d multi reads" % totalReadCounts["multi"])
     countStringList.append("%d multi reads count > %d discarded" % (totalReadCounts["multiDiscard"], maxMultiReadCount))
     if dataType == "RNA":
         countStringList.append("%d spliced reads" % totalReadCounts["splice"])
+        countStringList.append("%d spliced multireads" % totalReadCounts["multisplice"])
 
     print string.join(countStringList, "\n")
     outputCountText = string.join(countStringList, "\t")
@@ -295,6 +301,29 @@ def makeRdsFromBam(label, samFileName, outDbName, init=True, doIndex=False, useS
             rds.buildIndex(defaultCacheSize)
 
 
+def getMultiReadIDCounts(samFileName, fileMode):
+    try:
+        samfile = pysam.Samfile(samFileName, fileMode)
+    except ValueError:
+        print "samfile index not found"
+        sys.exit(1)
+
+    readIDCounts = {}
+    for read in samfile.fetch(until_eof=True):
+        pairReadSuffix = getPairedReadNumberSuffix(read)
+        readName = "%s%s" % (read.qname, pairReadSuffix)
+        try:
+            readIDCounts[readName] += 1
+        except KeyError:
+            readIDCounts[readName] = 1
+
+    for readID in readIDCounts.keys():
+        if readIDCounts[readID] == 1:
+            del readIDCounts[readID]
+
+    return readIDCounts
+
+
 def getRDSEntry(alignedRead, readName, chrom, readSize, weight=1):
     start = int(alignedRead.pos)
     stop = int(start + readSize)
@@ -308,11 +337,11 @@ def getRDSEntry(alignedRead, readName, chrom, readSize, weight=1):
     return (readName, chrom, start, stop, sense, 1.0/weight, '', mismatches)
 
 
-def getRDSSpliceEntry(alignedRead, readName, chrom, readSize):
-    (readName, chrom, start, stop, sense, weight, flag, mismatches) = getRDSEntry(alignedRead, readName, chrom, readSize)
+def getRDSSpliceEntry(alignedRead, readName, chrom, readSize, weight=1):
+    (readName, chrom, start, stop, sense, weight, flag, mismatches) = getRDSEntry(alignedRead, readName, chrom, readSize, weight)
     startL, startR, stopL, stopR = getSpliceBounds(start, readSize, alignedRead.cigar)
     
-    return (readName, chrom, startL, stopL, startR, stopR, sense, 1.0, "", mismatches)
+    return (readName, chrom, startL, stopL, startR, stopR, sense, weight, "", mismatches)
 
 
 def getPairedReadNumberSuffix(read):