erange version 4.0a dev release
[erange.git] / weighMultireads.py
index ed27edf77b6c3d2912b00b0c4c64d6fb7f50167a..f4b1691a07a08b15988539f75f1fbb2a62e148b2 100755 (executable)
@@ -12,10 +12,15 @@ try:
 except:
     pass
 
-from commoncode import readDataset
-import sys, time, string, optparse
+import sys
+import time
+import string
+import optparse
+import ReadDataset
+from commoncode import getConfigParser, getConfigBoolOption, getConfigOption
 
-print "%prog: version 3.1"
+
+print "weighMultireads: version 3.3"
 
 def main(argv=None):
     if not argv:
@@ -23,13 +28,7 @@ def main(argv=None):
 
     usage = "usage: python %s rdsfile [--radius bp] [--noradius] [--usePairs maxDist] [--verbose] [--cache pages]"
 
-    parser = optparse.OptionParser(usage=usage)
-    parser.add_option("--radius", type="int", dest="radius")
-    parser.add_option("--noradius", action="store_false", dest="doRadius")
-    parser.add_option("--usePairs", type="int", dest="pairDist")
-    parser.add_option("--verbose", action="store_true", dest="verbose")
-    parser.add_option("--cache", type="int", dest="cachePages")
-    parser.set_defaults(radius=None, doRadius=True, pairDist=None, verbose=False, cachePages=None)
+    parser = getParser(usage)
     (options, args) = parser.parse_args(argv[1:])
 
     if len(args) < 1:
@@ -41,260 +40,294 @@ def main(argv=None):
     weighMultireads(rdsfile, options.radius, options.doRadius, options.pairDist, options.verbose, options.cachePages)
 
 
-def weighMultireads(rdsfile, radius=None, doRadius=True, pairDist=None, verbose=False, cachePages=None):
+def getParser(usage):
+    parser = optparse.OptionParser(usage=usage)
+    parser.add_option("--radius", type="int", dest="radius")
+    parser.add_option("--noradius", action="store_false", dest="doRadius")
+    parser.add_option("--usePairs", type="int", dest="pairDist")
+    parser.add_option("--verbose", action="store_true", dest="verbose")
+    parser.add_option("--cache", type="int", dest="cachePages")
 
-    if radius is not None:
-        doRadius = True
-    else:
-        radius = 100
+    configParser = getConfigParser()
+    section = "weighMultireads"
+    radius = getConfigOption(configParser, section, "radius", None)
+    doRadius = getConfigBoolOption(configParser, section, "doRadius", True)
+    pairDist = getConfigOption(configParser, section, "pairDist", None)
+    verbose = getConfigBoolOption(configParser, section, "verbose", False)
+    cachePages = getConfigOption(configParser, section, "cachePages", None)
+    
+    parser.set_defaults(radius=radius, doRadius=doRadius, pairDist=pairDist, verbose=verbose, cachePages=cachePages)
 
-    usePairs = False
-    if pairDist is not None:
-        usePairs = True
+    return parser
+
+
+def weighMultireads(rdsfile, radius=None, doRadius=True, pairDist=None, verbose=False, cachePages=None):
 
-    tooFar = pairDist * 10
-    
-    doCache = False
     if cachePages is not None:
         doCache = True
     else:
+        doCache = False
         cachePages = 1
 
-    RDS = readDataset(rdsfile, verbose = True, cache=doCache)
-    readlen = RDS.getReadSize()
-    halfreadlen = readlen / 2
-
+    RDS = ReadDataset.ReadDataset(rdsfile, verbose = True, cache=doCache)
     if cachePages > RDS.getDefaultCacheSize():
         RDS.setDBcache(cachePages)
 
     if verbose:
         print time.ctime()
 
-    multiIDs = RDS.getReadIDs(uniqs=False,multi=True)
+    multiIDs = RDS.getReadIDs(uniqs=False, multi=True)
     if verbose:
         print "got multiIDs ", time.ctime()
 
-    fixedPair = 0
     fixedReads = []
-    if usePairs:
-        print "doing pairs with pairDist = %d" % pairDist
-        uidDict = {}
-        midDict = {}
-        jointList = []
-        bothMultiList = []
-        mainIDList = []
-        guDict = {}
-        muDict = {}
-
-        if RDS.dataType == "RNA":
-            uniqIDs = RDS.getReadIDs(uniqs=True,multi=False,splices=True)
-        else:
-            uniqIDs = RDS.getReadIDs(uniqs=True,multi=False,splices=False)
+    if pairDist is not None:
+        fixedReads = reweighUsingPairs(RDS, pairDist, multiIDs, verbose)
 
-        if verbose:
-            print "got uniqIDs ", time.ctime()
+    if radius is not None:
+        doRadius = True
+    else:
+        radius = 100
 
-        for readID in uniqIDs:
-            (mainID, pairID) = readID.split("/")
-            try:
-                uidDict[mainID].append(pairID)
-            except:
-                uidDict[mainID] = [pairID]
-                mainIDList.append(mainID)
+    if doRadius:
+        reweighUsingRadius(RDS, radius, multiIDs, fixedReads, verbose)
 
-        if verbose:
-            print "uidDict all ", len(uidDict), time.ctime()
+    if doCache:
+        RDS.saveCacheDB(rdsfile)
 
-        for mainID in mainIDList:
-            if len(uidDict[mainID]) == 2:
-                del uidDict[mainID]
+    if verbose:
+        print "finished", time.ctime()
 
-        if verbose:
-            print "uidDict first candidates ", len(uidDict), time.ctime()
 
-        for readID in multiIDs:
-            (frontID, multiplicity) = readID.split("::")
-            (mainID, pairID) = frontID.split("/")
-            try:
-                if pairID not in midDict[mainID]:
-                    midDict[mainID].append(pairID)
-            except:
-                midDict[mainID] = [pairID]
+def reweighUsingPairs(RDS, pairDist, multiIDs, verbose=False):
+    fixedPair = 0
+    tooFar = pairDist * 10
+    readlen = RDS.getReadSize()
+    fixedReads = []
+    print "doing pairs with pairDist = %d" % pairDist
+    hasSplices = RDS.dataType == "RNA"
+    uniqIDs = RDS.getReadIDs(uniqs=True, multi=False, splices=hasSplices)
+
+    if verbose:
+        print "got uniqIDs ", time.ctime()
 
-        if verbose:
-            print "all multis ", len(midDict), time.ctime()
+    jointList, bothMultiList = getReadIDLists(uniqIDs, multiIDs, verbose)
+    uniqDict = getUniqAndSpliceReadsFromReadIDs(RDS, jointList, verbose)
+    if verbose:
+        print "guDict actual ", len(uniqDict), time.ctime()
 
-        mainIDList = uidDict.keys()
-        for mainID in mainIDList:
-            if mainID not in midDict:
-                del uidDict[mainID]
+    multiDict = getMultiReadsFromReadIDs(RDS, jointList, bothMultiList, verbose)
+    if verbose:
+        print "muDict actual ", len(multiDict), time.ctime()
+
+    RDS.setSynchronousPragma("OFF")
+    for readID in jointList:
+        try:
+            ustart = uniqDict[readID]["start"]
+            ustop = ustart + readlen
+        except KeyError:
+            ustart = uniqDict[readID]["startL"]
+            ustop = uniqDict[readID]["stopR"]
+
+        uniqReadChrom = uniqDict[readID]["chrom"]
+        multiReadList = multiDict[readID]
+        numMultiReads = len(multiReadList)
+        bestMatch = [tooFar] * numMultiReads
+        found = False
+        for index in range(numMultiReads):
+            mstart = multiReadList[index]["start"]
+            multiReadChrom = multiReadList[index]["chrom"]
+            mpair = multiReadList[index]["pairID"]
+            if uniqReadChrom != multiReadChrom:
+                continue
 
-        if verbose:
-            print "uidDict actual candidates ", len(uidDict), time.ctime()
+            if abs(mstart - ustart) < pairDist:
+                bestMatch[index] = abs(mstart - ustart)
+                found = True
+            elif abs(mstart - ustop) < pairDist:
+                bestMatch[index] = abs(mstart - ustop)
+                found = True
 
-        for readID in midDict:
-            listLen = len(midDict[readID])
-            if listLen == 1:
-                if readID in uidDict:
-                    jointList.append(readID)
-            elif listLen == 2:
-                bothMultiList.append(readID)
+        if found:
+            theMatch = -1
+            theDist = tooFar
+            reweighList = []
+            for index in range(numMultiReads):
+                if theDist > bestMatch[index]:
+                    theMatch = index
+                    theDist = bestMatch[index]
+
+            theID = string.join([readID, mpair], "/")
+            for index in range(numMultiReads):
+                if index == theMatch:
+                    score = 1 - ((numMultiReads - 1) / (100. * numMultiReads))
+                else:
+                    score = 1 / (100. * numMultiReads)
+
+                start = multiReadList[index][0]
+                chrom = "chr%s" % multiReadList[index][1]
+                reweighList.append((round(score,3), chrom, start, theID))
+
+            #TODO: Is this right? If match index is 0 are we doing nothing?
+            if theMatch > 0:
+                RDS.reweighMultireads(reweighList)
+                fixedPair += 1
+                if verbose and fixedPair % 10000 == 1:
+                    print "fixed %d" % fixedPair
+                    print uniqDict[readID]
+                    print multiDict[readID]
+                    print reweighList
+
+                fixedReads.append(theID)
+
+    RDS.setSynchronousPragma("ON")
+
+    print "fixed %d pairs" % fixedPair
+    print time.ctime()
+
+    return fixedReads
+
+
+def getReadIDLists(uniqIDs, multiIDs, verbose=False):
+    uidDict = {}
+    mainIDList = []
+    for readID in uniqIDs:
+        (mainID, pairID) = readID.split("/")
+        try:
+            uidDict[mainID].append(pairID)
+        except:
+            uidDict[mainID] = [pairID]
+            mainIDList.append(mainID)
 
-        if verbose:
-            print "joint ", len(jointList), time.ctime()
-            print "bothMulti ", len(bothMultiList), time.ctime()
+    if verbose:
+        print "uidDict all ", len(uidDict), time.ctime()
 
-        del uidDict
-        del midDict
-        del mainIDList
-        del uniqIDs
+    for mainID in mainIDList:
+        if len(uidDict[mainID]) == 2:
+            del uidDict[mainID]
 
-        uniqDict = RDS.getReadsDict(noSense=True, withChrom=True, withPairID=True, doUniqs=True, readIDDict=True)
-        if verbose:
-            print "got uniq dict ", len(uniqDict), time.ctime()
-
-        if RDS.dataType == "RNA":
-            spliceDict = RDS.getSplicesDict(noSense=True, withChrom=True, withPairID=True, readIDDict=True)
-            if verbose:
-                print "got splice dict ", len(spliceDict), time.ctime()
-
-        for readID in jointList:
-            try:
-                guDict[readID] = uniqDict[readID][0]
-            except:
-                if RDS.dataType == "RNA":
-                    guDict[readID] = spliceDict[readID][0]
-
-        del uniqDict
-        del spliceDict
-        if verbose:
-            print "guDict actual ", len(guDict), time.ctime()
+    if verbose:
+        print "uidDict first candidates ", len(uidDict), time.ctime()
+
+    midDict = {}
+    for readID in multiIDs:
+        (frontID, multiplicity) = readID.split("::")
+        (mainID, pairID) = frontID.split("/")
+        try:
+            if pairID not in midDict[mainID]:
+                midDict[mainID].append(pairID)
+        except:
+            midDict[mainID] = [pairID]
 
-        multiDict = RDS.getReadsDict(noSense=True, withChrom=True, withPairID=True, doUniqs=False, doMulti=True, readIDDict=True)
-        if verbose:
-            print "got multi dict ", len(multiDict), time.ctime()
+    if verbose:
+        print "all multis ", len(midDict), time.ctime()
 
-        for readID in jointList:
-            muDict[readID] = multiDict[readID]
+    mainIDList = uidDict.keys()
+    for mainID in mainIDList:
+        if mainID not in midDict:
+            del uidDict[mainID]
 
-        for readID in bothMultiList:
-            muDict[readID] = multiDict[readID]
+    if verbose:
+        print "uidDict actual candidates ", len(uidDict), time.ctime()
+
+    jointList = []
+    bothMultiList = []
+    for readID in midDict:
+        listLen = len(midDict[readID])
+        if listLen == 1:
+            if readID in uidDict:
+                jointList.append(readID)
+        elif listLen == 2:
+            bothMultiList.append(readID)
 
-        del multiDict
-        if verbose:
-            print "muDict actual ", len(muDict), time.ctime()
-
-        RDS.setSynchronousPragma("OFF")
-        for readID in jointList:
-            try:
-                (ustart, uchrom, upair) = guDict[readID]
-                ustop = ustart + readlen
-            except:
-                (ustart, lstop, rstart, ustop, uchrom, upair) = guDict[readID]
-
-            muList = muDict[readID]
-            muLen = len(muList)
-            bestMatch = [tooFar] * muLen
-            found = False
-            for index in range(muLen):
-                (mstart, mchrom, mpair) = muList[index]
-                if uchrom != mchrom:
-                    continue
-
-                if abs(mstart - ustart) < pairDist:
-                    bestMatch[index] = abs(mstart - ustart)
-                    found = True
-                elif abs(mstart - ustop) < pairDist:
-                    bestMatch[index] = abs(mstart - ustop)
-                    found = True
-
-            if found:
-                theMatch = -1
-                theDist = tooFar
-                reweighList = []
-                for index in range(muLen):
-                    if theDist > bestMatch[index]:
-                        theMatch = index
-                        theDist = bestMatch[index]
-
-                theID = string.join([readID, mpair], "/")
-                for index in range(muLen):
-                    if index == theMatch:
-                        score = 1 - (muLen - 1) / (100. * (muLen))
-                    else:
-                        score = 1 / (100. * muLen)
-
-                    start = muList[index][0]
-                    chrom = "chr%s" % muList[index][1]
-                    reweighList.append((round(score,3), chrom, start, theID))
-
-                if theMatch > 0:
-                    RDS.reweighMultireads(reweighList)
-                    fixedPair += 1
-                    if verbose and fixedPair % 10000 == 1:
-                        print "fixed %d" % fixedPair
-                        print guDict[readID]
-                        print muDict[readID]
-                        print reweighList
-
-                    fixedReads.append(theID)
-
-        RDS.setSynchronousPragma("ON")
-
-        del guDict
-        del muDict
-        print "fixed %d pairs" % fixedPair
-        print time.ctime()
+    if verbose:
+        print "joint ", len(jointList), time.ctime()
+        print "bothMulti ", len(bothMultiList), time.ctime()
 
-    skippedReads = 0
-    if doRadius:
-        print "doing uniq read radius with radius = %d" % radius
-        multiDict = RDS.getReadsDict(noSense=True, withWeight=True, withChrom=True, withID=True, doUniqs=False, doMulti=True, readIDDict=True)
-        print "got multiDict"
-        RDS.setSynchronousPragma("OFF")
-        rindex = 0
-        for readID in multiIDs:
-            theID = readID
-            if theID in fixedReads:
-                skippedReads += 1
-                continue
+    return jointList, bothMultiList
 
-            if "::" in readID:
-                (readID, multiplicity) = readID.split("::")
-
-            scores = []
-            coords = []
-            for read in multiDict[readID]:
-                (start, weight, rID, chrom) = read
-                achrom = "chr%s" % chrom
-                regionStart = start + halfreadlen - radius
-                regionStop = start + halfreadlen + radius 
-                uniqs = RDS.getCounts(achrom, regionStart, regionStop, uniqs=True, multi=False, splices=False, reportCombined=True)
-                scores.append(uniqs + 1)
-                coords.append((achrom, start, theID))
-
-            total = float(sum(scores))
-            reweighList = []
-            for index in range(len(scores)):
-                reweighList.append((round(scores[index]/total,2), coords[index][0], coords[index][1], coords[index][2]))
 
-            RDS.reweighMultireads(reweighList)
-            rindex += 1
-            if rindex % 10000 == 0:
-                print rindex
+def getUniqAndSpliceReadsFromReadIDs(RDS, jointList, verbose=False):
+    uniqReadsDict = {}
+    uniqDict = RDS.getReadsDict(noSense=True, withChrom=True, withPairID=True, doUniqs=True, readIDDict=True)
+    if verbose:
+        print "got uniq dict ", len(uniqDict), time.ctime()
 
-        RDS.setSynchronousPragma("ON")
+    if RDS.dataType == "RNA":
+        spliceDict = RDS.getSplicesDict(noSense=True, withChrom=True, withPairID=True, readIDDict=True)
         if verbose:
-            print "skipped ", skippedReads
+            print "got splice dict ", len(spliceDict), time.ctime()
 
-        print "reweighted ", rindex
+    for readID in jointList:
+        try:
+            uniqReadsDict[readID] = uniqDict[readID][0]
+        except KeyError:
+            if RDS.dataType == "RNA":
+                uniqReadsDict[readID] = spliceDict[readID][0]
+
+    return uniqReadsDict
 
-    if doCache:
-        RDS.saveCacheDB(rdsfile)
 
+def getMultiReadsFromReadIDs(RDS, jointList, bothMultiList, verbose=False):
+    multiReadSubsetDict = {}
+    multiDict = RDS.getReadsDict(noSense=True, withChrom=True, withPairID=True, doUniqs=False, doMulti=True, readIDDict=True)
     if verbose:
-        print "finished", time.ctime()
-    
+        print "got multi dict ", len(multiDict), time.ctime()
+
+    for readID in jointList:
+        multiReadSubsetDict[readID] = multiDict[readID]
+
+    for readID in bothMultiList:
+        multiReadSubsetDict[readID] = multiDict[readID]
+
+    return multiReadSubsetDict
+
+
+def reweighUsingRadius(RDS, radius, multiIDs, readsToSkip=[], verbose=False):
+    skippedReads = 0
+    readlen = RDS.getReadSize()
+    halfreadlen = readlen / 2
+    print "doing uniq read radius with radius = %d" % radius
+    multiDict = RDS.getReadsDict(noSense=True, withWeight=True, withChrom=True, withID=True, doUniqs=False, doMulti=True, readIDDict=True)
+    print "got multiDict"
+    RDS.setSynchronousPragma("OFF")
+    reweighedCount = 0
+    for readID in multiIDs:
+        originalMultiReadID = readID
+        if originalMultiReadID in readsToSkip:
+            skippedReads += 1
+            continue
+
+        if "::" in readID:
+            (readID, multiplicity) = readID.split("::")
+
+        scores = []
+        coords = []
+        for read in multiDict[readID]:
+            start = read["start"]
+            chromosome = "chr%s" % read["chrom"]
+            regionStart = start + halfreadlen - radius
+            regionStop = start + halfreadlen + radius 
+            uniqs = RDS.getCounts(chromosome, regionStart, regionStop, uniqs=True, multi=False, splices=False, reportCombined=True)
+            scores.append(uniqs + 1)
+            coords.append((chromosome, start, originalMultiReadID))
+
+        total = float(sum(scores))
+        reweighList = []
+        for index in range(len(scores)):
+            reweighList.append((round(scores[index]/total,2), coords[index][0], coords[index][1], coords[index][2]))
+
+        RDS.reweighMultireads(reweighList)
+        reweighedCount += 1
+        if reweighedCount % 10000 == 0:
+            print reweighedCount
+
+    RDS.setSynchronousPragma("ON")
+    if verbose:
+        print "skipped ", skippedReads
+
+    print "reweighted ", reweighedCount
+
 
 if __name__ == "__main__":
     main(sys.argv)
\ No newline at end of file