erange version 4.0a dev release
[erange.git] / geneMrnaCountsWeighted.py
index 7acf0b92fd1d93da0c29d4bdbba08af6c7d2a937..38e853aec6172be26b448066692193f1537b6926 100755 (executable)
@@ -4,13 +4,15 @@ try:
 except:
     print 'psyco not running'
 
-import sys, optparse
-from commoncode import readDataset, getMergedRegions, getFeaturesByChromDict
+import sys
+import optparse
+from commoncode import getMergedRegions, getFeaturesByChromDict, getConfigParser, getConfigOption, getConfigBoolOption
+import ReadDataset
 from cistematic.genomes import Genome
-from cistematic.core.geneinfo import geneinfoDB
+from commoncode import getGeneInfoDict
 from cistematic.core import chooseDB, cacheGeneDB, uncacheGeneDB
 
-print '%s: version 4.1' % sys.argv[0]
+print "geneMrnaCountsWeighted: version 4.3"
 
 
 def main(argv=None):
@@ -19,21 +21,7 @@ def main(argv=None):
 
     usage = "usage: python %s genome rdsfile uniqcountfile outfile [options]"
 
-    parser = optparse.OptionParser(usage=usage)
-    parser.add_option("--stranded", action="store_false", dest="ignoreSense")
-    parser.add_option("--uniq", action="store_true", dest="withUniqs")
-    parser.add_option("--multi", action="store_true", dest="withMulti")
-    parser.add_option("--record", action="store_true", dest="recording",
-                      help="ignored with uniq reads")
-    parser.add_option("--accept", dest="acceptfile")
-    parser.add_option("--cache", type="int", dest="cachePages")
-    parser.add_option("--verbose", action="store_true", dest="doVerbose")
-    parser.add_option("--models", dest="extendGenome")
-    parser.add_option("--replacemodels", action="store_true", dest="replaceModels")
-    parser.set_defaults(ignoreSense=True, withUniqs=False, withMulti=False, recording=False,
-                        acceptfile=None, cachePages=None, doVerbose=False, extendGenome="",
-                        replaceModels=False)
-
+    parser = makeParser(usage)
     (options, args) = parser.parse_args(argv[1:])
 
     if len(args) < 4:
@@ -46,13 +34,48 @@ def main(argv=None):
     outfilename = args[3]
 
     geneMrnaCountsWeighted(genome, hitfile, countfile, outfilename, options.ignoreSense,
-                           options.withUniqs, options.withMulti, options.recording,
+                           options.withUniqs, options.withMulti,
                            options.acceptfile, options.cachePages, options.doVerbose,
                            options.extendGenome, options.replaceModels)
 
 
+def makeParser(usage=""):
+    parser = optparse.OptionParser(usage=usage)
+    parser.add_option("--stranded", action="store_false", dest="ignoreSense")
+    parser.add_option("--uniq", action="store_true", dest="withUniqs")
+    parser.add_option("--multi", action="store_true", dest="withMulti")
+    parser.add_option("--accept", dest="acceptfile")
+    parser.add_option("--cache", type="int", dest="cachePages")
+    parser.add_option("--verbose", action="store_true", dest="doVerbose")
+    parser.add_option("--models", dest="extendGenome")
+    parser.add_option("--replacemodels", action="store_true", dest="replaceModels")
+
+    configParser = getConfigParser()
+    section = "geneMrnaCountsWeighted"
+    ignoreSense = getConfigBoolOption(configParser, section, "ignoreSense", True)
+    withUniqs = getConfigBoolOption(configParser, section, "withUniqs", False)
+    withMulti = getConfigBoolOption(configParser, section, "withMulti", False)
+    acceptfile = getConfigOption(configParser, section, "acceptfile", None)
+    cachePages = getConfigOption(configParser, section, "cachePages", None)
+    doVerbose = getConfigBoolOption(configParser, section, "doVerbose", False)
+    extendGenome = getConfigOption(configParser, section, "extendGenome", "")
+    replaceModels = getConfigBoolOption(configParser, section, "replaceModels", False)
+
+    parser.set_defaults(ignoreSense=ignoreSense, withUniqs=withUniqs, withMulti=withMulti,
+                        acceptfile=acceptfile, cachePages=cachePages, doVerbose=doVerbose, extendGenome=extendGenome,
+                        replaceModels=replaceModels)
+
+    return parser
+
+
+#TODO: A user reported a performance issue: long run times under these conditions:
+#    a small number of reads (~40-50M)
+#    all features on a single chromosome
+#
+#    The user states that this has been a long-standing problem.
+
 def geneMrnaCountsWeighted(genome, hitfile, countfile, outfilename, ignoreSense=True,
-                           withUniqs=False, withMulti=False, recording=False, acceptfile=None,
+                           withUniqs=False, withMulti=False, acceptfile=None,
                            cachePages=None, doVerbose=False, extendGenome="", replaceModels=False):
 
     if (not withUniqs and not withMulti) or (withUniqs and withMulti):
@@ -62,70 +85,40 @@ def geneMrnaCountsWeighted(genome, hitfile, countfile, outfilename, ignoreSense=
     if cachePages is not None:
         cacheGeneDB(genome)
         hg = Genome(genome, dbFile=chooseDB(genome), inRAM=True)
-        idb = geneinfoDB(cache=True)
         print "%s cached" % genome
         doCache = True
     else:
         doCache = False
         cachePages = 0
         hg = Genome(genome, inRAM=True)
-        idb = geneinfoDB()
-
-    if acceptfile is not None:
-        acceptDict = getMergedRegions(acceptfile, maxDist=0, keepLabel=True, verbose=True)
-    else:
-        acceptDict = {}
-
-    if recording and withUniqs:
-        recording = False
 
     if extendGenome:
         if replaceModels:
             print "will replace gene models with %s" % extendGenome
         else:
             print "will extend gene models with %s" % extendGenome
-    else:
-        replaceModels = False
 
-    if extendGenome != "":
-        hg.extendFeatures(extendGenome, replace = replaceModels)
-    
-    hitRDS = readDataset(hitfile, verbose = True, cache=doCache)
+        hg.extendFeatures(extendGenome, replace=replaceModels)
+
+    hitRDS = ReadDataset.ReadDataset(hitfile, verbose=doVerbose, cache=doCache)
     if cachePages > hitRDS.getDefaultCacheSize():
         hitRDS.setDBcache(cachePages)
 
-    readlen = hitRDS.getReadSize()
-
-    geneinfoDict = idb.getallGeneInfo(genome)
-    geneannotDict = hg.allAnnotInfo()
-    gidCount = {}
-    gidReadDict = {}
-
-    featuresByChromDict = getFeaturesByChromDict(hg, acceptDict)
-    gidList = hg.allGIDs()
+    allGIDs = set(hg.allGIDs())
+    if acceptfile is not None:
+        regionDict = getMergedRegions(acceptfile, maxDist=0, keepLabel=True, verbose=doVerbose)
+        for chrom in regionDict:
+            for region in regionDict[chrom]:
+                allGIDs.add(region.label)
+    else:
+        regionDict = {}
 
-    gidList.sort()
-    for chrom in acceptDict:
-        for (label, start, stop, length) in acceptDict[chrom]:
-            if label not in gidList:
-                gidList.append(label)
+    featuresByChromDict = getFeaturesByChromDict(hg, regionDict)
 
-    for gid in gidList:
-        gidCount[gid] = 0
-        gidReadDict[gid] = []
-
-    uniqueCountDict = {}
+    gidReadDict = {}
     read2GidDict = {}
-
-    uniquecounts = open(countfile)
-    for line in uniquecounts:
-        fields = line.strip().split()
-        # add a pseudo-count here to ease calculations below
-        uniqueCountDict[fields[0]] = float(fields[-1]) + 1
-
-    uniquecounts.close()
-
-    outfile = open(outfilename, "w")
+    for gid in allGIDs:
+        gidReadDict[gid] = []
 
     index = 0
     if withMulti and not withUniqs:
@@ -133,124 +126,98 @@ def geneMrnaCountsWeighted(genome, hitfile, countfile, outfilename, ignoreSense=
     else:
         chromList = hitRDS.getChromosomes(fullChrom=False)
 
-    for achrom in chromList:
-        if achrom not in featuresByChromDict:
+    readlen = hitRDS.getReadSize()
+    for chromosome in chromList:
+        if doNotProcessChromosome(chromosome, featuresByChromDict.keys()):
             continue
 
-        print "\n" + achrom + " ",
-        startFeature = 0
-        fullchrom = "chr" + achrom
+        print "\n%s " % chromosome,
+        fullchrom = "chr%s" % chromosome
         hitDict = hitRDS.getReadsDict(noSense=ignoreSense, fullChrom=True, chrom=fullchrom, withID=True, doUniqs=withUniqs, doMulti=withMulti)
-        featList = featuresByChromDict[achrom]
-        if ignoreSense:
-            for (tagStart, tagReadID) in hitDict[fullchrom]:
-                index += 1
-                if index % 100000 == 0:
-                    print "read %d" % index,
-
-                stopPoint = tagStart + readlen
-                if startFeature < 0:
-                    startFeature = 0
-
-                for (start, stop, gid, sense, ftype) in featList[startFeature:]:
-                    if tagStart > stop:
-                        startFeature += 1
-                        continue
-
-                    if start > stopPoint:
-                        startFeature -= 100
-                        break
-
-                    if start <= tagStart <= stop:
-                        try:
-                            gidReadDict[gid].append(tagReadID)
-                            if tagReadID in read2GidDict:
-                                if gid not in read2GidDict[tagReadID]:
-                                    read2GidDict[tagReadID].append(gid)
-                            else:
-                                read2GidDict[tagReadID] = [gid]
-
-                            gidCount[gid] += 1
-                        except:
-                            print "gid %s not in gidReadDict" % gid
-
-                        stopPoint = stop
-        else:
-            for (tagStart, tSense, tagReadID) in hitDict[fullchrom]:
-                index += 1
-                if index % 100000 == 0:
-                    print "read %d" % index,
-
-                stopPoint = tagStart + readlen
-                if startFeature < 0:
-                    startFeature = 0
-
-                for (start, stop, gid, sense, ftype) in featList[startFeature:]:
-                    if tagStart > stop:
-                        startFeature += 1
-                        continue
-
-                    if start > stopPoint:
-                        startFeature -= 100
-                        break
-
-                    if sense == "R":
-                        sense = "-"
-                    else:
-                        sense = "+"
-
-                    if start <= tagStart <= stop and sense == tSense:
-                        try:
-                            gidReadDict[gid].append(tagReadID)
-                            if tagReadID in read2GidDict:
-                                if gid not in read2GidDict[tagReadID]:
-                                    read2GidDict[tagReadID].append(gid)
-                            else:
-                                read2GidDict[tagReadID] = [gid]
-
-                            gidCount[gid] += 1
-                        except:
-                            print "gid %s not in gidReadDict" % gid
-
-                        stopPoint = stop
-
-    for gid in gidList:
-        if "FAR" not in gid:
-            symbol = "LOC" + gid
-            geneinfo = ""
+        featureList = featuresByChromDict[chromosome]
+
+        readGidList, totalProcessedReads = getReadGIDs(hitDict, fullchrom, featureList, readlen, index)
+        index = totalProcessedReads
+        for (tagReadID, gid) in readGidList:
             try:
-                geneinfo = geneinfoDict[gid]
-                if genome == "celegans":
-                    symbol = geneinfo[0][1]
+                gidReadDict[gid].append(tagReadID)
+                if tagReadID in read2GidDict:
+                    read2GidDict[tagReadID].add(gid)
                 else:
-                    symbol = geneinfo[0][0]
-            except:
-                try:
-                    symbol = geneannotDict[(genome, gid)][0]
-                except:
-                    symbol = "LOC" + gid
-        else:
-            symbol = gid
+                    read2GidDict[tagReadID] = set([gid])
+            except KeyError:
+                print "gid %s not in gidReadDict" % gid
 
-        tagCount = 0.
-        for readID in gidReadDict[gid]:
-            try:
-                tagValue = uniqueCountDict[gid]
-            except:
-                tagValue = 1
+    writeCountsToFile(outfilename, countfile, allGIDs, hg, gidReadDict, read2GidDict, doVerbose, doCache)
+    if doCache:
+        uncacheGeneDB(genome)
 
-            tagDenom = 0.
-            for aGid in read2GidDict[readID]:
-                try:
-                    tagDenom += uniqueCountDict[aGid]
-                except:
-                    tagDenom += 1
 
-        try:
-            tagCount += tagValue / tagDenom
-        except ZeroDivisionError:
-            tagCount = 0
-    
+def doNotProcessChromosome(chromosome, chromosomeList):
+    return chromosome not in chromosomeList
+
+
+def getReadGIDs(hitDict, fullchrom, featList, readlen, index):
+
+    startFeature = 0
+    readGidList = []
+    ignoreSense = True
+    for read in hitDict[fullchrom]:
+        tagStart = read["start"]
+        tagReadID = read["readID"]
+        if read.has_key("sense"):
+            tagSense = read["sense"]
+            ignoreSense = False
+
+        index += 1
+        if index % 100000 == 0:
+            print "read %d" % index,
+
+        stopPoint = tagStart + readlen
+        if startFeature < 0:
+            startFeature = 0
+
+        for (start, stop, gid, sense, ftype) in featList[startFeature:]:
+            if tagStart > stop:
+                startFeature += 1
+                continue
+
+            if start > stopPoint:
+                startFeature -= 100
+                break
+
+            if not ignoreSense:
+                if sense == "R":
+                    sense = "-"
+                else:
+                    sense = "+"
+
+            if start <= tagStart <= stop and (ignoreSense or tagSense == sense):
+                readGidList.append((tagReadID, gid))
+                stopPoint = stop
+
+    return readGidList, index
+
+
+def writeCountsToFile(outFilename, countFilename, allGIDs, genome, gidReadDict, read2GidDict, doVerbose=False, doCache=False):
+
+    uniqueCountDict = {}
+    uniquecounts = open(countFilename)
+    for line in uniquecounts:
+        fields = line.strip().split()
+        # add a pseudo-count here to ease calculations below
+        #TODO: figure out why this was done in prior implementation...
+        uniqueCountDict[fields[0]] = float(fields[-1]) + 1
+
+    uniquecounts.close()
+
+    genomeName = genome.genome
+    geneinfoDict = getGeneInfoDict(genomeName, cache=doCache)
+    geneannotDict = genome.allAnnotInfo()
+    outfile = open(outFilename, "w")
+    for gid in allGIDs:
+        symbol = getGeneSymbol(gid, genomeName, geneinfoDict, geneannotDict)
+        tagCount = getTagCount(uniqueCountDict, gid, gidReadDict, read2GidDict)
         if doVerbose:
             print "%s %s %f" % (gid, symbol, tagCount)
 
@@ -258,8 +225,49 @@ def geneMrnaCountsWeighted(genome, hitfile, countfile, outfilename, ignoreSense=
 
     outfile.close()
 
-    if doCache:
-        uncacheGeneDB(genome)
+
+def getGeneSymbol(gid, genomeName, geneinfoDict, geneannotDict):
+    if "FAR" not in gid:
+        symbol = "LOC%s" % gid
+        geneinfo = ""
+        try:
+            geneinfo = geneinfoDict[gid]
+            if genomeName == "celegans":
+                symbol = geneinfo[0][1]
+            else:
+                symbol = geneinfo[0][0]
+        except (KeyError, IndexError):
+            try:
+                symbol = geneannotDict[(genomeName, gid)][0]
+            except (KeyError, IndexError):
+                symbol = "LOC%s" % gid
+    else:
+        symbol = gid
+
+    return symbol
+
+
+def getTagCount(uniqueCountDict, gid, gidReadDict, read2GidDict):
+    tagCount = 0.
+    for readID in gidReadDict[gid]:
+        try:
+            tagValue = uniqueCountDict[gid]
+        except KeyError:
+            tagValue = 1
+
+        tagDenom = 0.
+        for relatedGID in read2GidDict[readID]:
+            try:
+                tagDenom += uniqueCountDict[relatedGID]
+            except KeyError:
+                tagDenom += 1
+
+        try:
+            tagCount += tagValue / tagDenom
+        except ZeroDivisionError:
+            pass
+
+    return tagCount
 
 
 if __name__ == "__main__":