first pass cleanup of cistematic/genomes; change bamPreprocessing
[erange.git] / rnafarPairs.py
index d1baebd3e7e6e61278c4c141a84fe2396e323f28..2c5aff43ced96fe650d5cc52562cb998033b4d04 100755 (executable)
@@ -1,9 +1,3 @@
-#
-#  RNAFARpairs.py
-#  ENRAGE
-#
-#  Created by Ali Mortazavi on 11/2/08.
-#
 """ usage: python rnafarpairs.py genome goodfile rdsfile outfile [options]
            looks at all chromosomes simultaneously: is both slow and takes up large amount of RAM
 """
@@ -13,26 +7,21 @@ try:
 except:
     pass
 
-import sys, time, optparse
-from commoncode import readDataset
-from cistematic.core.geneinfo import geneinfoDB
-from cistematic.genomes import Genome
+import sys
+import time
+import optparse
+import pysam
+from commoncode import getGeneInfoDict, getGeneAnnotDict, getConfigParser, getConfigIntOption, getConfigBoolOption, isSpliceEntry
+
 
 def main(argv=None):
     if not argv:
         argv = sys.argv
 
-    print "%prog: version 3.6"
+    print "rnafarPairs: version 3.7"
     usage = "usage: python %prog genome goodfile rdsfile outfile [options]"
 
-    parser = optparse.OptionParser(usage=usage)
-    parser.add_option("--verbose", action="store_true", dest="doVerbose",
-                      help="verbose output")
-    parser.add_option("--cache", action="store_true", dest="doCache",
-                      help="use cache")
-    parser.add_option("--maxDist", type="int", dest="maxDist",
-                      help="maximum distance")
-    parser.set_defaults(doVerbose=False, doCache=False, maxDist=500000)
+    parser = makeParser(usage)
     (options, args) = parser.parse_args(argv[1:])
 
     if len(args) < 4:
@@ -41,129 +30,240 @@ def main(argv=None):
 
     genome = args[0]
     goodfilename = args[1]
-    rdsfile = args[2]
+    bamfilename = args[2]
     outfilename = args[3]
 
-    rnaFarPairs(genome, goodfilename, rdsfile, outfilename, options.doVerbose, options.doCache, options.maxDist)
-    
+    bamfile = pysam.Samfile(bamfilename, "rb")
+
+    rnaFarPairs(genome, goodfilename, bamfile, outfilename, options.doVerbose, options.doCache, options.maxDist)
+
+
+def makeParser(usage=""):
+    parser = optparse.OptionParser(usage=usage)
+    parser.add_option("--verbose", action="store_true", dest="doVerbose",
+                      help="verbose output")
+    parser.add_option("--cache", action="store_true", dest="doCache",
+                      help="use cache")
+    parser.add_option("--maxDist", type="int", dest="maxDist",
+                      help="maximum distance")
+
+    configParser = getConfigParser()
+    section = "rnafarPairs"
+    doVerbose = getConfigBoolOption(configParser, section, "doVerbose", False)
+    doCache = getConfigBoolOption(configParser, section, "doCache", False)
+    maxDist = getConfigIntOption(configParser, section, "maxDist", 500000)
+
+    parser.set_defaults(doVerbose=doVerbose, doCache=doCache, maxDist=maxDist)
+
+    return parser
+
+
+def rnaFarPairs(genome, goodfilename, bamfile, outfilename, doVerbose=False, doCache=False, maxDist=500000):
+    """ map all candidate regions that have paired ends overlapping with known genes
+    """
+
+    chromosomeList = [chrom for chrom in bamfile.references if chrom != "chrM"]
+    regions = {}
+    for chromosome in chromosomeList:
+        regions[chromosome] = []
 
-def rnaFarPairs(genome, goodfilename, rdsfile, outfilename, doVerbose=False, doCache=False, maxDist=500000):
     goodDict = {}
     goodfile = open(goodfilename)
     for line in goodfile:
         fields = line.split()
-        goodDict[fields[0]] = line
-
-    RDS = readDataset(rdsfile, verbose = True, cache=doCache)
-    rdsChromList = RDS.getChromosomes()
+        label = fields[0]
+        start = int(fields[2])
+        stop = int(fields[3])
+        goodDict[label] = line
+        regions[chromosome].append((start, stop, label))
 
+    goodfile.close()
     if doVerbose:
         print time.ctime()
 
     distinct = 0
     total = 0
     outfile = open(outfilename,"w")
-
-    idb = geneinfoDB()
-    if genome == "dmelanogaster":
-        geneinfoDict = idb.getallGeneInfo(genome, infoKey="locus")
-    else:
-        geneinfoDict = idb.getallGeneInfo(genome)
-
-    hg = Genome(genome)
-    geneannotDict = hg.allAnnotInfo()
-
+    geneinfoDict = getGeneInfoDict(genome)
+    geneannotDict = getGeneAnnotDict(genome)
     assigned = {}
     farConnected = {}
-    for achrom in rdsChromList:
-        if achrom == "chrM":
-            continue
-
-        print achrom
-        uniqDict = RDS.getReadsDict(fullChrom=True, chrom=achrom, noSense=True, withFlag=True, withPairID=True, doUniqs=True, readIDDict=True)
+    for chromosome in chromosomeList:
+        print chromosome
+        regionList = regions[chromosome].sort()
+        uniqDict, pairCount = getUniqueReadIDFlags(bamfile, regionList, chromosome, maxDist)
         if doVerbose:
             print len(uniqDict), time.ctime()    
 
-        for readID in uniqDict:
-            readList = uniqDict[readID]
-            if len(readList) == 2:
-                total += 1
-                (start1, flag1, pair1) = readList[0]
-                (start2, flag2, pair2) = readList[1]
-
-                if flag1 != flag2:
-                    dist = abs(start1 - start2)
-                    if flag1 != "NM" and flag2 != "NM" and dist < maxDist:
-                        geneID = ""
-                        saw1 = False
-                        saw2 = False
-                        if flag1 in goodDict:
-                            geneID = flag2
-                            farFlag = flag1
-                            saw1 = True
-
-                        if flag2 in goodDict:
-                            geneID = flag1
-                            farFlag = flag2
-                            saw2 = True
-
-                        if saw1 or saw2:
-                            total += 1
-
-                        if saw1 and saw2:
-                            if flag1 < flag2:
-                                geneID = flag1
-                                farFlag = flag2
-                            else:
-                                geneID = flag2
-                                farFlag = flag1
-
-                            if geneID in farConnected:
-                                farConnected[geneID].append(farFlag)
-                            else:
-                                farConnected[geneID] = [farFlag]
-                        elif geneID != "":
-                            try:
-                                if genome == "dmelanogaster":
-                                    symbol = geneinfoDict["Dmel_" + geneID][0][0]
-                                else:
-                                    symbol = geneinfoDict[geneID][0][0]
-                            except:
-                                try:
-                                    symbol = geneannotDict[(genome, geneID)][0]
-                                except:
-                                    symbol = "LOC" + geneID
-
-                            symbol = symbol.strip()
-                            symbol = symbol.replace(" ","|")
-                            symbol = symbol.replace("\t","|")
-                            if farFlag not in assigned:
-                                assigned[farFlag] = (symbol, geneID)
-                                print "%s %s %s" % (symbol, geneID, goodDict[farFlag].strip())
-                                outfile.write("%s %s %s" % (symbol, geneID, goodDict[farFlag]))
-                                distinct += 1
-
-    farIndex = 0
+        total += pairCount
+        for readID, readList in uniqDict.items():
+            flags = (readList[0]["flag"], readList[1]["flag"])
+            processed, distinctPairs = writeFarPairsToFile(flags, goodDict, genome, geneinfoDict, geneannotDict, outfile, assigned, farConnected)
+            total += processed
+            distinct += distinctPairs
+
+    entriesWritten = writeUnassignedEntriesToFile(farConnected, assigned, goodDict, outfile)
+    distinct += entriesWritten
+    outfile.write("#distinct: %d\ttotal: %d\n" % (distinct, total))
+    outfile.close()
+    print "distinct: %d\ttotal: %d" % (distinct, total)
+    print time.ctime()
+
+
+def getUniqueReadIDFlags(bamfile, regions, chromosome, maxDist):
+    """ Returns dictionary of readsIDs with each entry consisting of a list of dictionaries of read start and read flag.
+        Only returns unique non-spliced read pairs matching the criteria given in processReads().
+    """
+    start = 1
+    readDict = {}
+    for regionstart, regionstop, regionname in regions:
+        for alignedread in bamfile.fetch(chromosome, start, regionstop):
+            if alignedread.opt("NH") == 1 and not isSpliceEntry(alignedread.cigar):
+                if alignedread.pos >= regionstart:
+                    flag = regionname
+                else:
+                    flag = alignedread.opt("ZG")
+
+                try:
+                    readDict[alignedread.qname].append({"start": alignedread.pos, "flag": flag})
+                except KeyError:
+                    readDict[alignedread.qname] = [{"start": alignedread.pos, "flag": flag}]
+
+        start = regionstop + 1
+
+    for alignedread in bamfile.fetch(chromosome, start):
+        if alignedread.opt("NH") == 1 and not isSpliceEntry(alignedread.cigar):
+            flag = alignedread.opt("ZG")
+
+            try:
+                readDict[alignedread.qname].append({"start": alignedread.pos, "flag": flag})
+            except KeyError:
+                readDict[alignedread.qname] = [{"start": alignedread.pos, "flag": flag}]
+
+    pairCount = len(readDict.keys())
+    for readID, readList in readDict.items():
+        if len(readList) != 2:
+            del readDict[readID]
+            pairCount -= 1
+        elif not processReads(readList, maxDist):
+            del readDict[readID]
+
+    return readDict, pairCount
+
+
+def processReads(reads, maxDist=500000):
+    """ For a pair of readID's to be acceptable:
+            - flags must be different
+            - neither flag can be 'NM'
+            - the read starts have to be within maxDist
+    """
+    process = False
+    dist = abs(reads[0]["start"] - reads[1]["start"])
+    flag1 = reads[0]["flag"]
+    flag2 = reads[1]["flag"]
+
+    if flag1 != flag2 and flag1 != "NM" and flag2 != "NM" and dist < maxDist:
+        process = True
+
+    return process
+
+
+def writeFarPairsToFile(flags, goodDict, genome, geneInfoDict, geneAnnotDict, outfile, assigned, farConnected):
+    """ Writes out the region information along with symbol and geneID for paired reads where one read
+        of the pair is in a known gene and the other is in a new region.  If both reads in the pair are
+        in new regions the region is added to farConnected.  No action is taken if both reads in the
+        pair are in known genes.
+    """
+    flag1, flag2 = flags
+    total = 0
+    distinct = 0
+    read1IsGood = flag1 in goodDict
+    read2IsGood = flag2 in goodDict
+
+    if read1IsGood and read2IsGood:
+        if flag1 < flag2:
+            geneID = flag1
+            farFlag = flag2
+        else:
+            geneID = flag2
+            farFlag = flag1
+
+        try:
+            farConnected[geneID].append(farFlag)
+        except KeyError:
+            farConnected[geneID] = [farFlag]
+    elif read1IsGood or read2IsGood:
+        total = 1
+        if read2IsGood:
+            farFlag = flag2
+            geneID = flag1
+        else:
+            farFlag = flag1
+            geneID = flag2
+
+        symbol = getGeneSymbol(genome, geneID, geneInfoDict, geneAnnotDict)
+        if farFlag not in assigned:
+            assigned[farFlag] = (symbol, geneID)
+            print "%s %s %s" % (symbol, geneID, goodDict[farFlag].strip())
+            outfile.write("%s %s %s" % (symbol, geneID, goodDict[farFlag]))
+            distinct = 1
+
+    return total, distinct
+
+
+def getGeneSymbol(genome, geneID, geneInfoDict, geneAnnotDict):
+    try:
+        if genome == "dmelanogaster":
+            symbol = geneInfoDict["Dmel_%s" % geneID][0][0]
+        else:
+            symbol = geneInfoDict[geneID][0][0]
+    except (KeyError, IndexError):
+        try:
+            symbol = geneAnnotDict[(genome, geneID)][0]
+        except (KeyError, IndexError):
+            symbol = "LOC%s" % geneID
+
+    symbol = symbol.strip()
+    symbol = symbol.replace(" ","|")
+    symbol = symbol.replace("\t","|")
+
+    return symbol
+
+
+def writeUnassignedEntriesToFile(farConnected, assigned, goodDict, outfile):
+    total, written = writeUnassignedPairsToFile(farConnected, assigned, goodDict, outfile)
+    writeUnassignedGoodReadsToFile(total, goodDict, assigned, outfile)
+
+    return written
+
+
+def writeUnassignedPairsToFile(farConnected, assigned, goodDict, outfile):
+    total = 0
+    written = 0
     for farFlag in farConnected:
         geneID = ""
         symbol = ""
         idList = [farFlag] + farConnected[farFlag]
-        for oneID in idList:
-            if oneID in assigned:
-                (symbol, geneID) = assigned[oneID]
+        for ID in idList:
+            if ID in assigned:
+                (symbol, geneID) = assigned[ID]
 
         if geneID == "":
-            farIndex += 1
-            symbol = "FAR%d" % farIndex
-            geneID = -1 * farIndex
+            total += 1
+            symbol = "FAR%d" % total
+            geneID = -1 * total
+
+        for ID in idList:
+            if ID not in assigned:
+                print "%s %s %s" % (symbol, geneID, goodDict[ID].strip())
+                outfile.write("%s %s %s" % (symbol, geneID, goodDict[ID]))
+                written += 1
+                assigned[ID] = (symbol, geneID)
+
+    return total, written
 
-        for oneID in idList:
-            if oneID not in assigned:
-                print "%s %s %s" % (symbol, geneID, goodDict[oneID].strip())
-                outfile.write("%s %s %s" % (symbol, geneID, goodDict[oneID]))
-                distinct += 1
-                assigned[oneID] = (symbol, geneID)
 
+def writeUnassignedGoodReadsToFile(farIndex, goodDict, assigned, outfile):
     for farFlag in goodDict:
         if farFlag not in assigned:
             farIndex += 1
@@ -171,10 +271,6 @@ def rnaFarPairs(genome, goodfilename, rdsfile, outfilename, doVerbose=False, doC
             print line.strip()
             outfile.write(line)
 
-    outfile.write("#distinct: %d\ttotal: %d\n" % (distinct, total))
-    outfile.close()
-    print "distinct: %d\ttotal: %d" % (distinct, total)
-    print time.ctime()
 
 if __name__ == "__main__":
     main(sys.argv)
\ No newline at end of file