convert standard analysis pipelines to use bam format natively
[erange.git] / rnafarPairs.py
index 4d70c4965ef77c2eea4f646e58526db83682ffa6..2c5aff43ced96fe650d5cc52562cb998033b4d04 100755 (executable)
@@ -1,9 +1,3 @@
-#
-#  RNAFARpairs.py
-#  ENRAGE
-#
-#  Created by Ali Mortazavi on 11/2/08.
-#
 """ usage: python rnafarpairs.py genome goodfile rdsfile outfile [options]
            looks at all chromosomes simultaneously: is both slow and takes up large amount of RAM
 """
@@ -16,8 +10,8 @@ except:
 import sys
 import time
 import optparse
-import ReadDataset
-from commoncode import getGeneInfoDict, getGeneAnnotDict, getConfigParser, getConfigIntOption, getConfigBoolOption
+import pysam
+from commoncode import getGeneInfoDict, getGeneAnnotDict, getConfigParser, getConfigIntOption, getConfigBoolOption, isSpliceEntry
 
 
 def main(argv=None):
@@ -36,10 +30,12 @@ def main(argv=None):
 
     genome = args[0]
     goodfilename = args[1]
-    rdsfile = args[2]
+    bamfilename = args[2]
     outfilename = args[3]
 
-    rnaFarPairs(genome, goodfilename, rdsfile, outfilename, options.doVerbose, options.doCache, options.maxDist)
+    bamfile = pysam.Samfile(bamfilename, "rb")
+
+    rnaFarPairs(genome, goodfilename, bamfile, outfilename, options.doVerbose, options.doCache, options.maxDist)
 
 
 def makeParser(usage=""):
@@ -62,16 +58,26 @@ def makeParser(usage=""):
     return parser
 
 
-def rnaFarPairs(genome, goodfilename, rdsfile, outfilename, doVerbose=False, doCache=False, maxDist=500000):
+def rnaFarPairs(genome, goodfilename, bamfile, outfilename, doVerbose=False, doCache=False, maxDist=500000):
+    """ map all candidate regions that have paired ends overlapping with known genes
+    """
+
+    chromosomeList = [chrom for chrom in bamfile.references if chrom != "chrM"]
+    regions = {}
+    for chromosome in chromosomeList:
+        regions[chromosome] = []
+
     goodDict = {}
     goodfile = open(goodfilename)
     for line in goodfile:
         fields = line.split()
-        goodDict[fields[0]] = line
+        label = fields[0]
+        start = int(fields[2])
+        stop = int(fields[3])
+        goodDict[label] = line
+        regions[chromosome].append((start, stop, label))
 
     goodfile.close()
-    RDS = ReadDataset.ReadDataset(rdsfile, verbose = True, cache=doCache)
-    chromosomeList = RDS.getChromosomes()
     if doVerbose:
         print time.ctime()
 
@@ -83,23 +89,18 @@ def rnaFarPairs(genome, goodfilename, rdsfile, outfilename, doVerbose=False, doC
     assigned = {}
     farConnected = {}
     for chromosome in chromosomeList:
-        if doNotProcessChromosome(chromosome):
-            continue
-
         print chromosome
-        uniqDict = RDS.getReadsDict(fullChrom=True, chrom=chromosome, noSense=True, withFlag=True, doUniqs=True, readIDDict=True)
+        regionList = regions[chromosome].sort()
+        uniqDict, pairCount = getUniqueReadIDFlags(bamfile, regionList, chromosome, maxDist)
         if doVerbose:
             print len(uniqDict), time.ctime()    
 
-        for readID in uniqDict:
-            readList = uniqDict[readID]
-            if len(readList) == 2:
-                total += 1
-                if processReads(readList[:2], maxDist):
-                    flags = (readList[0]["flag"], readList[1]["flag"])
-                    processed, distinctPairs = writeFarPairsToFile(flags, goodDict, genome, geneinfoDict, geneannotDict, outfile, assigned, farConnected)
-                    total += processed
-                    distinct += distinctPairs
+        total += pairCount
+        for readID, readList in uniqDict.items():
+            flags = (readList[0]["flag"], readList[1]["flag"])
+            processed, distinctPairs = writeFarPairsToFile(flags, goodDict, genome, geneinfoDict, geneannotDict, outfile, assigned, farConnected)
+            total += processed
+            distinct += distinctPairs
 
     entriesWritten = writeUnassignedEntriesToFile(farConnected, assigned, goodDict, outfile)
     distinct += entriesWritten
@@ -109,15 +110,55 @@ def rnaFarPairs(genome, goodfilename, rdsfile, outfilename, doVerbose=False, doC
     print time.ctime()
 
 
-def doNotProcessChromosome(chromosome):
-    return chromosome == "chrM"
+def getUniqueReadIDFlags(bamfile, regions, chromosome, maxDist):
+    """ Returns dictionary of readsIDs with each entry consisting of a list of dictionaries of read start and read flag.
+        Only returns unique non-spliced read pairs matching the criteria given in processReads().
+    """
+    start = 1
+    readDict = {}
+    for regionstart, regionstop, regionname in regions:
+        for alignedread in bamfile.fetch(chromosome, start, regionstop):
+            if alignedread.opt("NH") == 1 and not isSpliceEntry(alignedread.cigar):
+                if alignedread.pos >= regionstart:
+                    flag = regionname
+                else:
+                    flag = alignedread.opt("ZG")
 
+                try:
+                    readDict[alignedread.qname].append({"start": alignedread.pos, "flag": flag})
+                except KeyError:
+                    readDict[alignedread.qname] = [{"start": alignedread.pos, "flag": flag}]
 
-def processReads(reads, maxDist):
+        start = regionstop + 1
+
+    for alignedread in bamfile.fetch(chromosome, start):
+        if alignedread.opt("NH") == 1 and not isSpliceEntry(alignedread.cigar):
+            flag = alignedread.opt("ZG")
+
+            try:
+                readDict[alignedread.qname].append({"start": alignedread.pos, "flag": flag})
+            except KeyError:
+                readDict[alignedread.qname] = [{"start": alignedread.pos, "flag": flag}]
+
+    pairCount = len(readDict.keys())
+    for readID, readList in readDict.items():
+        if len(readList) != 2:
+            del readDict[readID]
+            pairCount -= 1
+        elif not processReads(readList, maxDist):
+            del readDict[readID]
+
+    return readDict, pairCount
+
+
+def processReads(reads, maxDist=500000):
+    """ For a pair of readID's to be acceptable:
+            - flags must be different
+            - neither flag can be 'NM'
+            - the read starts have to be within maxDist
+    """
     process = False
-    start1 = reads[0]["start"]
-    start2 = reads[1]["start"]
-    dist = abs(start1 - start2)
+    dist = abs(reads[0]["start"] - reads[1]["start"])
     flag1 = reads[0]["flag"]
     flag2 = reads[1]["flag"]
 
@@ -128,6 +169,11 @@ def processReads(reads, maxDist):
 
 
 def writeFarPairsToFile(flags, goodDict, genome, geneInfoDict, geneAnnotDict, outfile, assigned, farConnected):
+    """ Writes out the region information along with symbol and geneID for paired reads where one read
+        of the pair is in a known gene and the other is in a new region.  If both reads in the pair are
+        in new regions the region is added to farConnected.  No action is taken if both reads in the
+        pair are in known genes.
+    """
     flag1, flag2 = flags
     total = 0
     distinct = 0
@@ -147,7 +193,7 @@ def writeFarPairsToFile(flags, goodDict, genome, geneInfoDict, geneAnnotDict, ou
         except KeyError:
             farConnected[geneID] = [farFlag]
     elif read1IsGood or read2IsGood:
-        total += 1
+        total = 1
         if read2IsGood:
             farFlag = flag2
             geneID = flag1
@@ -155,30 +201,35 @@ def writeFarPairsToFile(flags, goodDict, genome, geneInfoDict, geneAnnotDict, ou
             farFlag = flag1
             geneID = flag2
 
-        try:
-            if genome == "dmelanogaster":
-                symbol = geneInfoDict["Dmel_%s" % geneID][0][0]
-            else:
-                symbol = geneInfoDict[geneID][0][0]
-        except (KeyError, IndexError):
-            try:
-                symbol = geneAnnotDict[(genome, geneID)][0]
-            except (KeyError, IndexError):
-                symbol = "LOC%s" % geneID
-
-        symbol = symbol.strip()
-        symbol = symbol.replace(" ","|")
-        symbol = symbol.replace("\t","|")
-
+        symbol = getGeneSymbol(genome, geneID, geneInfoDict, geneAnnotDict)
         if farFlag not in assigned:
             assigned[farFlag] = (symbol, geneID)
             print "%s %s %s" % (symbol, geneID, goodDict[farFlag].strip())
             outfile.write("%s %s %s" % (symbol, geneID, goodDict[farFlag]))
-            distinct += 1
+            distinct = 1
 
     return total, distinct
 
 
+def getGeneSymbol(genome, geneID, geneInfoDict, geneAnnotDict):
+    try:
+        if genome == "dmelanogaster":
+            symbol = geneInfoDict["Dmel_%s" % geneID][0][0]
+        else:
+            symbol = geneInfoDict[geneID][0][0]
+    except (KeyError, IndexError):
+        try:
+            symbol = geneAnnotDict[(genome, geneID)][0]
+        except (KeyError, IndexError):
+            symbol = "LOC%s" % geneID
+
+    symbol = symbol.strip()
+    symbol = symbol.replace(" ","|")
+    symbol = symbol.replace("\t","|")
+
+    return symbol
+
+
 def writeUnassignedEntriesToFile(farConnected, assigned, goodDict, outfile):
     total, written = writeUnassignedPairsToFile(farConnected, assigned, goodDict, outfile)
     writeUnassignedGoodReadsToFile(total, goodDict, assigned, outfile)