convert standard analysis pipelines to use bam format natively
[erange.git] / geneMrnaCounts.py
index 7b4a2cc819c30976692d9721ddb363756ebdf70a..10ccc837dca3e6b36f3c7a14d94d3ebd4938ccd4 100755 (executable)
@@ -6,12 +6,11 @@ except:
 
 import sys
 import optparse
-from commoncode import getFeaturesByChromDict, getConfigParser, getConfigOption, getConfigBoolOption
-import ReadDataset
+from commoncode import getFeaturesByChromDict, getConfigParser, getConfigOption, getConfigBoolOption, countThisRead
 from cistematic.genomes import Genome
 from cistematic.core.geneinfo import geneinfoDB
 
-print "geneMrnaCounts: version 5.2"
+print "geneMrnaCounts: version 6.0"
 
 
 def main(argv=None):
@@ -33,7 +32,7 @@ def main(argv=None):
 
     geneMrnaCounts(genomeName, hitfile, outfilename, options.trackStrand, options.doSplices,
                    options.doUniqs, options.doMulti, options.extendGenome, options.replaceModels,
-                   options.searchGID, options.countFeats, options.cachePages, options.markGID)
+                   options.searchGID, options.countFeats, options.cachePages)
 
 
 def getParser(usage):
@@ -47,7 +46,6 @@ def getParser(usage):
     parser.add_option("--searchGID", action="store_true", dest="searchGID")
     parser.add_option("--countfeatures", action="store_true", dest="countFeats")
     parser.add_option("--cache", type="int", dest="cachePages")
-    parser.add_option("--markGID", action="store_true", dest="markGID")
 
     configParser = getConfigParser()
     section = "geneMrnaCounts"
@@ -60,17 +58,16 @@ def getParser(usage):
     searchGID = getConfigBoolOption(configParser, section, "searchGID", False)
     countFeats = getConfigBoolOption(configParser, section, "countFeats", False)
     cachePages = getConfigOption(configParser, section, "cachePages", None)
-    markGID = getConfigBoolOption(configParser, section, "markGID", False)
 
     parser.set_defaults(trackStrand=trackStrand, doSplices=doSplices, doUniqs=doUniqs, doMulti=doMulti,
                         extendGenome=extendGenome, replaceModels=replaceModels, searchGID=searchGID,
-                        countFeats=countFeats, cachePages=cachePages, markGID=markGID)
+                        countFeats=countFeats, cachePages=cachePages)
 
     return parser
 
-def geneMrnaCounts(genomeName, hitfile, outfilename, trackStrand=False, doSplices=False,
+def geneMrnaCounts(genomeName, bamfile, outfilename, trackStrand=False, doSplices=False,
                    doUniqs=True, doMulti=False, extendGenome="", replaceModels=False,
-                   searchGID=False, countFeats=False, cachePages=None, markGID=False):
+                   searchGID=False, countFeats=False):
 
     if trackStrand:
         print "will track strandedness"
@@ -86,16 +83,6 @@ def geneMrnaCounts(genomeName, hitfile, outfilename, trackStrand=False, doSplice
     else:
         replaceModels = False
 
-    if cachePages is not None:
-        doCache = True
-    else:
-        cachePages = 100000
-        doCache = False
-
-    hitRDS = ReadDataset.ReadDataset(hitfile, verbose=True, cache=doCache)
-    if cachePages > hitRDS.getDefaultCacheSize():
-        hitRDS.setDBcache(cachePages)
-
     genome = Genome(genomeName, inRAM=True)
     if extendGenome != "":
         genome.extendFeatures(extendGenome, replace=replaceModels)
@@ -111,24 +98,15 @@ def geneMrnaCounts(genomeName, hitfile, outfilename, trackStrand=False, doSplice
     for gid in gidList:
         gidCount[gid] = 0
 
-    chromList = hitRDS.getChromosomes(fullChrom=False)
-    if len(chromList) == 0 and doSplices:
-        chromList = hitRDS.getChromosomes(table="splices", fullChrom=False)
-
-    if markGID:
-        print "Flagging all reads as NM"
-        hitRDS.setFlags("NM", uniqs=doUniqs, multi=doMulti, splices=doSplices)
-
-    for chrom in chromList:
+    chromosomeList = [chrom for chrom in bamfile.references if chrom != "chrM"]
+    for chrom in chromosomeList:
         if chrom not in featuresByChromDict:
             continue
 
         if countFeats:
             seenFeaturesByChromDict[chrom] = set([])
 
-        print "\nchr%s" % chrom
-        fullchrom = "chr%s" % chrom
-        regionList = []        
+        print "\nchr%s" % chrom      
         print "counting GIDs"
         for (start, stop, gid, featureSense, featureType) in featuresByChromDict[chrom]:
             try:
@@ -137,34 +115,31 @@ def geneMrnaCounts(genomeName, hitfile, outfilename, trackStrand=False, doSplice
                     if featureSense == "R":
                         checkSense = "-"
 
-                    regionData = (gid, fullchrom, start, stop, checkSense)
-                    count = hitRDS.getCounts(fullchrom, start, stop, uniqs=doUniqs, multi=doMulti, splices=doSplices, sense=checkSense)
+                    count = getCounts(chrom, start, stop, uniqs=doUniqs, multi=doMulti, splices=doSplices, sense=checkSense)
                 else:
-                    regionData = (gid, fullchrom, start, stop)
-                    count = hitRDS.getCounts(fullchrom, start, stop, uniqs=doUniqs, multi=doMulti, splices=doSplices)
+                    count = getCounts(chrom, start, stop, uniqs=doUniqs, multi=doMulti, splices=doSplices)
 
                 gidCount[gid] += count
-                if markGID:
-                    regionList.append(regionData)
-
                 if countFeats:
                     seenFeaturesByChromDict[chrom].add((start, stop, gid, featureSense))
             except:
                 print "problem with %s - skipping" % gid
 
-        if markGID:
-            print "marking GIDs"
-            hitRDS.flagReads(regionList, uniqs=doUniqs, multi=doMulti, splices=doSplices, sense=doStranded)
-            print "finished marking"
-
     print " "
     if countFeats:
         numFeatures = countFeatures(seenFeaturesByChromDict)
         print "saw %d features" % numFeatures
 
-    writeOutputFile(outfilename, genome, gidList, gidCount, searchGID)
-    if markGID and doCache:
-        hitRDS.saveCacheDB(hitfile)
+    writeOutputFile(outfilename, genome, gidCount, searchGID)
+
+
+def getCounts(bamfile, chrom, start, stop, uniqs=True, multi=False, splices=False, sense=''):
+    count = 0.0
+    for alignedread in bamfile.fetch(chrom, start, stop):
+        if countThisRead(alignedread, uniqs, multi, splices, sense):
+            count += 1.0/alignedread.opt('NH')
+
+    return count
 
 
 def countFeatures(seenFeaturesByChromDict):
@@ -178,18 +153,15 @@ def countFeatures(seenFeaturesByChromDict):
     return count
 
 
-def writeOutputFile(outfilename, genome, gidList, gidCount, searchGID):
+def writeOutputFile(outfilename, genome, gidCount, searchGID):
     geneAnnotDict = genome.allAnnotInfo()
     genomeName = genome.genome
     outfile = open(outfilename, "w")
     idb = geneinfoDB(cache=True)
     geneInfoDict = idb.getallGeneInfo(genomeName)
-    for gid in gidList:
+    for gid in gidCount:
         symbol = getGeneSymbol(gid, searchGID, geneInfoDict, idb, genomeName, geneAnnotDict)
-        if gid in gidCount:
-            outfile.write("%s\t%s\t%d\n" % (gid, symbol, gidCount[gid]))
-        else:
-            outfile.write("%s\t%s\t0\n" % (gid, symbol))
+        outfile.write("%s\t%s\t%d\n" % (gid, symbol, gidCount[gid]))
 
     outfile.close()