first pass cleanup of cistematic/genomes; change bamPreprocessing
[erange.git] / findall.py
1 """
2     usage: python $ERANGEPATH/findall.py label samplebamfile regionoutfile
3            [--control controlbamfile] [--minimum minHits] [--ratio minRatio]
4            [--spacing maxSpacing] [--listPeak] [--shift #bp | learn] [--learnFold num]
5            [--noshift] [--autoshift] [--reportshift] [--nomulti] [--minPlus fraction]
6            [--maxPlus fraction] [--leftPlus fraction] [--minPeak RPM] [--raw]
7            [--revbackground] [--pvalue self|back|none] [--nodirectionality]
8            [--strandfilter plus/minus] [--trimvalue percent] [--notrim]
9            [--cache pages] [--log altlogfile] [--flag aflag] [--append] [--RNA]
10
11            where values in brackets are optional and label is an arbitrary string.
12
13            Use -ratio (default 4 fold) to set the minimum fold enrichment
14            over the control, -minimum (default 4) is the minimum number of reads
15            (RPM) within the region, and -spacing (default 50; readlen with -RNA) to set the maximum
16            distance between reads in the region. -listPeak lists the peak of the
17            region. Peaks must be higher than -minPeak (default 0.5 RPM).
18            Pvalues are calculated from the sample (change with -pvalue),
19            unless the -revbackground flag and a control BAM file are provided.
20
21            By default, all numbers and parameters are on a reads per
22            million (RPM) basis. -raw will treat all settings, ratios and reported
23            numbers as raw counts rather than RPM. Use -notrim to turn off region
24            trimming and -trimvalue to control trimming (default 10% of peak signal)
25
26            The peak finder uses minimal directionality information that can
27            be turned off with -nodirectionality ; the fraction of + strand reads
28            required to be to the left of the peak (default 0.3) can be set with
29            -leftPlus ; -minPlus and -maxPlus change the minimum/maximum fraction
30            of plus reads in a region (defaults 0.25 and 0.75, respectively).
31
32            Use -shift to shift reads either by half the expected
33            fragment length (default 0 bp) or '-shift learn' to learn the shift
34            based on the first chromosome. Alternatively, use -autoshift to
35            calculate a per-region shift value, which
36            can be reported using -reportshift. -strandfilter should only be used
37            when explicitly calling unshifted stranded peaks from non-ChIP-seq
38            data such as directional RNA-seq. regionoutfile is written over by
39            default unless given the -append flag.
40 """
41
42 try:
43     import psyco
44     psyco.full()
45 except:
46     pass
47
48 import sys
49 import math
50 import string
51 import optparse
52 import operator
53 import pysam
54 from commoncode import writeLog, findPeak, getConfigParser, getConfigOption, getConfigIntOption, getConfigFloatOption, getConfigBoolOption, isSpliceEntry
55 import Region
56
57
# Version banner: echoed when the module loads and embedded in the output
# file description and in every writeLog entry.
versionString = "findall: version 3.2.1"
print(versionString)
60
class RegionDirectionError(Exception):
    """Signals that a candidate region was rejected by the directionality filter.

    findPeakRegions catches it around updateRegion and counts the region in
    statistics["failed"].
    """
63             
64
class RegionFinder(object):
    """Settings and running statistics for one findall peak-calling run.

    The region-scanning functions read the thresholds stored here and update
    self.statistics; the instance also renders the output description, the
    column header, the console settings summary and the footer.
    """

    def __init__(self, label, minRatio=4.0, minPeak=0.5, minPlusRatio=0.25, maxPlusRatio=0.75, leftPlusRatio=0.3, strandfilter="",
                 minHits=4.0, trimValue=0.1, doTrim=True, doDirectionality=True, shiftValue=0, maxSpacing=50, withFlag="",
                 normalize=True, listPeak=False, reportshift=False, stringency=1.0, controlfile=None, doRevBackground=False):
        """Store the peak-calling settings.

        label        -- string used to label every called region.
        minRatio     -- minimum fold enrichment; minPeak is lowered to it when smaller.
        strandfilter -- "plus" or "minus" restricts analysis to one strand and
                        overrides minPlusRatio/maxPlusRatio; anything else keeps both.
        trimValue    -- trimming threshold as a fraction (0.1 is shown as "10.0%").
        shiftValue   -- an integer shift, or the strings "learn" / "auto".
        stringency   -- clamped to at least 1.0.
        controlfile  -- control BAM path; control comparisons are enabled when not None.
        Other keywords are stored under the attribute names used by the rest of the module.
        """
        # running totals: sample regions (index/total), background regions
        # (mIndex/mTotal), directionality failures and trimming failures
        self.statistics = {"index": 0,
                           "total": 0,
                           "mIndex": 0,
                           "mTotal": 0,
                           "failed": 0,
                           "badRegionTrim": 0
        }

        self.regionLabel = label
        self.rnaSettings = False
        # read counts in millions; findall replaces these from the BAM headers
        self.controlRDSsize = 1.0
        self.sampleRDSsize = 1.0
        self.minRatio = minRatio
        self.minPeak = minPeak
        self.leftPlusRatio = leftPlusRatio
        self.stranded = "both"
        if strandfilter == "plus":
            self.stranded = "+"
            minPlusRatio = 0.9
            maxPlusRatio = 1.0
        elif strandfilter == "minus":
            self.stranded = "-"
            minPlusRatio = 0.0
            maxPlusRatio = 0.1

        if minRatio < minPeak:
            self.minPeak = minRatio

        self.minPlusRatio = minPlusRatio
        self.maxPlusRatio = maxPlusRatio
        self.strandfilter = strandfilter
        self.minHits = minHits
        self.trimValue = trimValue
        self.doTrim = doTrim
        self.doDirectionality = doDirectionality

        if self.doTrim:
            self.trimString = "%2.1f%%" % (100. * self.trimValue)
        else:
            self.trimString = "none"

        self.shiftValue = shiftValue
        self.maxSpacing = maxSpacing
        self.withFlag = withFlag
        self.normalize = normalize
        self.listPeak = listPeak
        self.reportshift = reportshift
        self.stringency = max(stringency, 1.0)
        self.controlfile = controlfile
        self.doControl = self.controlfile is not None
        # set by findall once the p-value type has been resolved
        self.doPvalue = False
        self.doRevBackground = doRevBackground


    def useRNASettings(self, readlen):
        """Switch to RNA settings: no shift, no trimming, no directionality, spacing = readlen."""
        self.rnaSettings = True
        self.shiftValue = 0
        self.doTrim = False
        self.doDirectionality = False
        self.maxSpacing = readlen


    def getHeader(self):
        """Return the tab-separated column header line for the region output file."""
        if self.normalize:
            countType = "RPM"
        else:
            countType = "COUNT"

        headerFields = ["#regionID\tchrom\tstart\tstop", countType, "fold\tmulti%"]

        if self.doDirectionality:
            headerFields.append("plus%\tleftPlus%")

        if self.listPeak:
            headerFields.append("peakPos\tpeakHeight")

        if self.reportshift:
            headerFields.append("readShift")

        if self.doPvalue:
            headerFields.append("pValue")

        return "\t".join(headerFields)


    def printSettings(self, ptype, useMulti, pValueType):
        """Print a blank line, the status messages and the options summary to stdout."""
        print("")
        self.printStatusMessages(ptype, useMulti)
        self.printOptionsSummary(useMulti, pValueType)


    def printStatusMessages(self, ptype, useMulti):
        """Print one line per notable setting and per problem with the requested p-value type."""
        if self.shiftValue == "learn":
            print("Will try to learn shift")

        if self.normalize:
            print("Normalizing to RPM")

        if self.doRevBackground:
            print("Swapping IP and background to calculate FDR")

        if ptype == "BACK":
            if not (self.doControl and self.doRevBackground):
                print("must have a control dataset and -revbackground for pValue type 'back'")
        elif ptype not in ("", "NONE", "SELF"):
            print("could not use pValue type : %s" % ptype)

        if self.withFlag != "":
            print("restrict to flag = %s" % self.withFlag)

        if not useMulti:
            print("using unique reads only")

        if self.rnaSettings:
            print("using settings appropriate for RNA: -nodirectionality -notrim -noshift")

        if self.strandfilter == "plus":
            print("only analyzing reads on the plus strand")
        elif self.strandfilter == "minus":
            print("only analyzing reads on the minus strand")


    def _formatShift(self):
        """Render shiftValue with %d when numeric, verbatim otherwise (e.g. "learn", "auto")."""
        try:
            return "%d" % self.shiftValue
        except TypeError:
            return "%s" % (self.shiftValue,)


    def printOptionsSummary(self, useMulti, pValueType):
        """Print the active thresholds and options to stdout."""
        print("\nenforceDirectionality=%s listPeak=%s nomulti=%s " % (self.doDirectionality, self.listPeak, not useMulti))
        print("spacing<%d minimum>%.1f ratio>%.1f minPeak=%.1f\ttrimmed=%s\tstrand=%s" % (self.maxSpacing, self.minHits, self.minRatio, self.minPeak, self.trimString, self.stranded))
        print("minPlus=%.2f maxPlus=%.2f leftPlus=%.2f shift=%s pvalue=%s" % (self.minPlusRatio, self.maxPlusRatio, self.leftPlusRatio, self._formatShift(), pValueType))


    def getAnalysisDescription(self, hitfile, useMulti, pValueType):
        """Return the '#'-prefixed description block written at the top of the output file."""
        description = ["#ERANGE %s" % versionString]
        if self.doControl:
            description.append("#enriched sample:\t%s (%.1f M reads)\n#control sample:\t%s (%.1f M reads)" % (hitfile, self.sampleRDSsize, self.controlfile, self.controlRDSsize))
        else:
            description.append("#enriched sample:\t%s (%.1f M reads)\n#control sample: none" % (hitfile, self.sampleRDSsize))

        if self.withFlag != "":
            description.append("#restrict to Flag = %s" % self.withFlag)

        description.append("#enforceDirectionality=%s listPeak=%s nomulti=%s" % (self.doDirectionality, self.listPeak, not useMulti))
        description.append("#spacing<%d minimum>%.1f ratio>%.1f minPeak=%.1f trimmed=%s strand=%s" % (self.maxSpacing, self.minHits, self.minRatio, self.minPeak, self.trimString, self.stranded))
        description.append("#minPlus=%.2f maxPlus=%.2f leftPlus=%.2f shift=%s pvalue=%s" % (self.minPlusRatio, self.maxPlusRatio, self.leftPlusRatio, self._formatShift(), pValueType))

        return "\n".join(description)


    def getFooter(self, bestShift):
        """Return the '#'-prefixed summary lines written after all regions.

        bestShift is reported only when shiftValue is "auto" and reportshift is set.
        """
        index = self.statistics["index"]
        mIndex = self.statistics["mIndex"]
        footerLines = ["#stats:\t%.1f RPM in %d regions" % (self.statistics["total"], index)]
        if self.doDirectionality:
            footerLines.append("#\t\t%d additional regions failed directionality filter" % self.statistics["failed"])

        if self.doRevBackground:
            # background regions as a percentage of sample regions, capped at 100
            try:
                percent = min(100. * (float(mIndex)/index), 100.)
            except ZeroDivisionError:
                percent = 0.

            footerLines.append("#%d regions (%.1f RPM) found in background (FDR = %.2f percent)" % (mIndex, self.statistics["mTotal"], percent))

        if self.shiftValue == "auto" and self.reportshift:
            footerLines.append("#mode of shift values: %d" % bestShift)

        if self.statistics["badRegionTrim"] > 0:
            footerLines.append("#%d regions discarded due to trimming problems" % self.statistics["badRegionTrim"])

        return "\n".join(footerLines)


    def updateControlStatistics(self, peak, sumAll, peakScore):
        """Tally a background (control) region that passes the peak filters.

        peak      -- peak object; numHits, numPlus and numLeftPlus are read.
        sumAll    -- read total of the region, added to statistics["mTotal"].
        peakScore -- peak height compared against minPeak.

        A region passing the height and plus-fraction filters but failing the
        left-plus test (or having no plus reads) is counted in statistics["failed"].
        """
        plusRatio = float(peak.numPlus)/peak.numHits
        if not (peakScore >= self.minPeak and self.minPlusRatio <= plusRatio <= self.maxPlusRatio):
            return

        if self.doDirectionality:
            # float() avoids Python 2 integer division when the counts are ints; with
            # no plus reads (possible when minPlusRatio is 0) the fraction is undefined.
            if peak.numPlus == 0 or not self.leftPlusRatio < float(peak.numLeftPlus) / peak.numPlus:
                self.statistics["failed"] += 1
                return

        self.statistics["mIndex"] += 1
        self.statistics["mTotal"] += sumAll
267
268
def usage():
    """Print the module docstring, which doubles as the command-line help text."""
    print(__doc__)
271
272
def main(argv=None):
    """Command-line entry point: parse argv, build a RegionFinder and run findall.

    argv -- full argument vector including the program name; defaults to sys.argv.

    Prints usage and exits with status 2 when fewer than three positional
    arguments (label, sample BAM, region output file) are supplied.
    """
    if not argv:
        argv = sys.argv

    parser = makeParser()
    (options, args) = parser.parse_args(argv[1:])

    if len(args) < 3:
        usage()
        sys.exit(2)

    factor = args[0]
    hitfile = args[1]
    outfilename = args[2]

    # precedence: --autoshift < --shift < --noshift
    shiftValue = 0
    if options.autoshift:
        shiftValue = "auto"

    if options.shift is not None:
        try:
            shiftValue = int(options.shift)
        except ValueError:
            if options.shift == "learn":
                shiftValue = "learn"
            else:
                # previously an unrecognized value was ignored without any notice
                print("could not use shift value : %s - using shift = %s" % (options.shift, shiftValue))

    if options.noshift:
        shiftValue = 0

    if options.doAppend:
        outputMode = "a"
    else:
        outputMode = "w"

    regionFinder = RegionFinder(factor, minRatio=options.minRatio, minPeak=options.minPeak, minPlusRatio=options.minPlusRatio,
                                maxPlusRatio=options.maxPlusRatio, leftPlusRatio=options.leftPlusRatio, strandfilter=options.strandfilter,
                                minHits=options.minHits, trimValue=options.trimValue, doTrim=options.doTrim,
                                doDirectionality=options.doDirectionality, shiftValue=shiftValue, maxSpacing=options.maxSpacing,
                                withFlag=options.withFlag, normalize=options.normalize, listPeak=options.listPeak,
                                reportshift=options.reportshift, stringency=options.stringency, controlfile=options.controlfile,
                                doRevBackground=options.doRevBackground)

    findall(regionFinder, hitfile, outfilename, options.logfilename, outputMode, options.rnaSettings,
            options.ptype, options.useMulti, options.combine5p)
318
319
def makeParser():
    """Build the optparse parser for findall.

    Defaults are read from the "findall" section of the configuration returned
    by getConfigParser, falling back to the literal values given below.
    """
    usage = __doc__

    parser = optparse.OptionParser(usage=usage)
    parser.add_option("--control", dest="controlfile")
    parser.add_option("--minimum", type="float", dest="minHits")
    parser.add_option("--ratio", type="float", dest="minRatio")
    parser.add_option("--spacing", type="int", dest="maxSpacing")
    parser.add_option("--listPeak", action="store_true", dest="listPeak")
    parser.add_option("--shift", dest="shift")
    parser.add_option("--learnFold", type="float", dest="stringency")
    # dest must match the "noshift" default below and main()'s options.noshift;
    # it used to be "noShift", so --noshift on the command line had no effect.
    parser.add_option("--noshift", action="store_true", dest="noshift")
    parser.add_option("--autoshift", action="store_true", dest="autoshift")
    parser.add_option("--reportshift", action="store_true", dest="reportshift")
    parser.add_option("--nomulti", action="store_false", dest="useMulti")
    parser.add_option("--minPlus", type="float", dest="minPlusRatio")
    parser.add_option("--maxPlus", type="float", dest="maxPlusRatio")
    parser.add_option("--leftPlus", type="float", dest="leftPlusRatio")
    parser.add_option("--minPeak", type="float", dest="minPeak")
    parser.add_option("--raw", action="store_false", dest="normalize")
    parser.add_option("--revbackground", action="store_true", dest="doRevBackground")
    parser.add_option("--pvalue", dest="ptype")
    parser.add_option("--nodirectionality", action="store_false", dest="doDirectionality")
    parser.add_option("--strandfilter", dest="strandfilter")
    parser.add_option("--trimvalue", type="float", dest="trimValue")
    parser.add_option("--notrim", action="store_false", dest="doTrim")
    parser.add_option("--cache", type="int", dest="cachePages")
    parser.add_option("--log", dest="logfilename")
    parser.add_option("--flag", dest="withFlag")
    parser.add_option("--append", action="store_true", dest="doAppend")
    parser.add_option("--RNA", action="store_true", dest="rnaSettings")
    parser.add_option("--combine5p", action="store_true", dest="combine5p")

    configParser = getConfigParser()
    section = "findall"
    minHits = getConfigFloatOption(configParser, section, "minHits", 4.0)
    minRatio = getConfigFloatOption(configParser, section, "minRatio", 4.0)
    maxSpacing = getConfigIntOption(configParser, section, "maxSpacing", 50)
    listPeak = getConfigBoolOption(configParser, section, "listPeak", False)
    shift = getConfigOption(configParser, section, "shift", None)
    stringency = getConfigFloatOption(configParser, section, "stringency", 4.0)
    noshift = getConfigBoolOption(configParser, section, "noshift", False)
    autoshift = getConfigBoolOption(configParser, section, "autoshift", False)
    reportshift = getConfigBoolOption(configParser, section, "reportshift", False)
    minPlusRatio = getConfigFloatOption(configParser, section, "minPlusRatio", 0.25)
    maxPlusRatio = getConfigFloatOption(configParser, section, "maxPlusRatio", 0.75)
    leftPlusRatio = getConfigFloatOption(configParser, section, "leftPlusRatio", 0.3)
    minPeak = getConfigFloatOption(configParser, section, "minPeak", 0.5)
    normalize = getConfigBoolOption(configParser, section, "normalize", True)
    logfilename = getConfigOption(configParser, section, "logfilename", "findall.log")
    withFlag = getConfigOption(configParser, section, "withFlag", "")
    doDirectionality = getConfigBoolOption(configParser, section, "doDirectionality", True)
    trimValue = getConfigFloatOption(configParser, section, "trimValue", 0.1)
    doTrim = getConfigBoolOption(configParser, section, "doTrim", True)
    doAppend = getConfigBoolOption(configParser, section, "doAppend", False)
    rnaSettings = getConfigBoolOption(configParser, section, "rnaSettings", False)
    cachePages = getConfigOption(configParser, section, "cachePages", None)
    ptype = getConfigOption(configParser, section, "ptype", "")
    controlfile = getConfigOption(configParser, section, "controlfile", None)
    doRevBackground = getConfigBoolOption(configParser, section, "doRevBackground", False)
    useMulti = getConfigBoolOption(configParser, section, "useMulti", True)
    strandfilter = getConfigOption(configParser, section, "strandfilter", "")
    combine5p = getConfigBoolOption(configParser, section, "combine5p", False)

    parser.set_defaults(minHits=minHits, minRatio=minRatio, maxSpacing=maxSpacing, listPeak=listPeak, shift=shift,
                        stringency=stringency, noshift=noshift, autoshift=autoshift, reportshift=reportshift,
                        minPlusRatio=minPlusRatio, maxPlusRatio=maxPlusRatio, leftPlusRatio=leftPlusRatio, minPeak=minPeak,
                        normalize=normalize, logfilename=logfilename, withFlag=withFlag, doDirectionality=doDirectionality,
                        trimValue=trimValue, doTrim=doTrim, doAppend=doAppend, rnaSettings=rnaSettings,
                        cachePages=cachePages, ptype=ptype, controlfile=controlfile, doRevBackground=doRevBackground, useMulti=useMulti,
                        strandfilter=strandfilter, combine5p=combine5p)

    return parser
393
394
def findall(regionFinder, hitfile, outfilename, logfilename="findall.log", outputMode="w", rnaSettings=False,
            ptype="", useMulti=True, combine5p=False):
    """Call enriched regions on every chromosome to process and write them to outfilename.

    regionFinder -- configured RegionFinder; its sizes, read length and statistics are set in place.
    hitfile      -- sample BAM path; its header must carry "Total" and "ReadLength" comments.
    outfilename  -- region output file, opened with outputMode ("w" or "a").
    logfilename  -- log file passed to writeLog at the start and end of the run.
    rnaSettings  -- apply RegionFinder.useRNASettings with the sample read length.
    ptype        -- requested p-value type ("NONE", "SELF", "BACK" or "").
    useMulti     -- include multireads (NH > 1).
    combine5p    -- forwarded to learnShift and findPeakRegions.

    The BAM files and the output file are closed even if processing fails.
    """
    writeLog(logfilename, versionString, " ".join(sys.argv[1:]))
    controlBAM = None
    sampleBAM = None
    try:
        if regionFinder.doControl:
            print("\ncontrol:")
            controlBAM = pysam.Samfile(regionFinder.controlfile, "rb")
            regionFinder.controlRDSsize = int(getHeaderComment(controlBAM.header, "Total")) / 1000000.

        print("\nsample:")
        sampleBAM = pysam.Samfile(hitfile, "rb")
        regionFinder.sampleRDSsize = int(getHeaderComment(sampleBAM.header, "Total")) / 1000000.
        pValueType = getPValueType(ptype, regionFinder.doControl, regionFinder.doRevBackground)
        regionFinder.doPvalue = pValueType != "none"
        regionFinder.readlen = int(getHeaderComment(sampleBAM.header, "ReadLength"))
        if rnaSettings:
            regionFinder.useRNASettings(regionFinder.readlen)

        regionFinder.printSettings(ptype, useMulti, pValueType)
        with open(outfilename, outputMode) as outfile:
            header = writeOutputFileHeader(regionFinder, outfile, hitfile, useMulti, pValueType)
            shiftDict = {}
            for chromosome in getChromosomeListToProcess(sampleBAM, controlBAM):
                #TODO: Really? Use first chr shift value for all of them
                maxSampleCoord = getMaxCoordinate(sampleBAM, chromosome, doMulti=useMulti)
                if regionFinder.shiftValue == "learn":
                    regionFinder.shiftValue = learnShift(regionFinder, sampleBAM, maxSampleCoord, chromosome, logfilename, outfilename, outfile, useMulti, controlBAM, combine5p)

                allregions, outregions = findPeakRegions(regionFinder, sampleBAM, maxSampleCoord, chromosome, logfilename, outfilename, outfile, useMulti, controlBAM, combine5p)
                if regionFinder.doRevBackground:
                    maxControlCoord = getMaxCoordinate(controlBAM, chromosome, doMulti=useMulti)
                    backregions = findBackgroundRegions(regionFinder, sampleBAM, controlBAM, maxControlCoord, chromosome, useMulti)
                else:
                    backregions = []
                    # NOTE(review): this also turns a requested "none" into "self"; presumably
                    # harmless if writeChromosomeResults checks regionFinder.doPvalue - confirm.
                    pValueType = "self"

                writeChromosomeResults(regionFinder, outregions, outfile, shiftDict, allregions, header, backregions=backregions, pValueType=pValueType)

            try:
                bestShift = getBestShiftInDict(shiftDict)
            except ValueError:
                bestShift = 0

            footer = regionFinder.getFooter(bestShift)
            print(footer)
            outfile.write(footer + "\n")
    finally:
        for bamfile in (sampleBAM, controlBAM):
            if bamfile is not None:
                bamfile.close()

    # One-line log summary with footer lines joined by " | ".  The former trailing
    # [:-1] assumed a final newline that getFooter never emits and cut a real character.
    writeLog(logfilename, versionString, outfilename + footer.replace("\n#", " | "))
445
446
def getHeaderComment(bamHeader, commentKey):
    """Return the value stored in a "key<TAB>value" comment of a BAM header.

    bamHeader  -- header mapping (e.g. a pysam Samfile's .header); comment lines
                  are read from its "CO" entry.
    commentKey -- first tab-separated field to match, e.g. "Total" or "ReadLength".

    Returns the second tab-separated field of the first matching comment, as a string.
    Raises KeyError(commentKey) when there are no comments or none matches.
    """
    try:
        comments = bamHeader["CO"]
    except KeyError:
        raise KeyError(commentKey)

    for comment in comments:
        fields = comment.split("\t")
        # a comment with no value field cannot satisfy the lookup
        if fields[0] == commentKey and len(fields) > 1:
            return fields[1]

    raise KeyError(commentKey)
454
455
def getPValueType(ptype, doControl, doRevBackground):
    """Resolve the p-value type ("none", "self" or "back") to use.

    An explicit ptype of "NONE"/"SELF" wins; "BACK" is honoured only with both a
    control and reversed background, otherwise it falls back to "self".  Any
    other ptype yields "back" when doRevBackground is set and "self" otherwise.
    """
    if ptype == "NONE":
        return "none"

    if ptype == "SELF":
        return "self"

    if ptype == "BACK":
        backgroundUsable = doControl and doRevBackground
        return "back" if backgroundUsable else "self"

    return "back" if doRevBackground else "self"
470
471
def writeOutputFileHeader(regionFinder, outfile, hitfile, useMulti, pValueType):
    """Write the analysis description and the column header line to outfile.

    Returns the column header string so callers can reuse it.
    """
    description = regionFinder.getAnalysisDescription(hitfile, useMulti, pValueType)
    outfile.write(description + "\n")
    columnHeader = regionFinder.getHeader()
    outfile.write(columnHeader + "\n")
    return columnHeader
478
479
def getChromosomeListToProcess(sampleBAM, controlBAM=None):
    """Return the sample's reference names to analyse, in sample order, excluding "chrM".

    When controlBAM is given, only references also present in the control are kept.
    """
    if controlBAM is None:
        return [chrom for chrom in sampleBAM.references if chrom != "chrM"]

    # set lookup keeps this linear for assemblies with many scaffolds
    controlReferences = set(controlBAM.references)
    return [chrom for chrom in sampleBAM.references if chrom in controlReferences and chrom != "chrM"]
487
488
def findPeakRegions(regionFinder, sampleBAM, maxCoord, chromosome, logfilename, outfilename,
                    outfile, useMulti, controlBAM, combine5p):
    """Call candidate enriched regions from the sample reads on one chromosome.

    Reads are collected into a run until previousRegionIsDone reports that the
    next read begins a new one; each finished run is scored by
    _evaluateSampleRegion.  logfilename, outfilename and outfile are not used here.

    Returns (allregions, outregions): int(read total) of every run examined, and
    the Region objects that passed all filters.
    """
    outregions = []
    allregions = []
    print("chromosome %s" % chromosome)
    previousHit = - 1 * regionFinder.maxSpacing
    readStartPositions = [-1]
    totalWeight = 0.0
    uniqueReadCount = 0
    reads = []
    numStartsInRegion = 0

    for alignedread in sampleBAM.fetch(chromosome):
        if doNotProcessRead(alignedread, doMulti=useMulti, strand=regionFinder.stranded, combine5p=combine5p):
            continue

        pos = alignedread.pos
        if previousRegionIsDone(pos, previousHit, regionFinder.maxSpacing, maxCoord):
            _evaluateSampleRegion(regionFinder, sampleBAM, controlBAM, chromosome, useMulti, readStartPositions,
                                  totalWeight, uniqueReadCount, reads, numStartsInRegion, allregions, outregions)
            # Reset unconditionally.  The former inline code used `continue` when trimming
            # or the peak lookup failed, which skipped this reset and the bookkeeping
            # below: the current read was dropped, previousHit did not advance, and the
            # rejected run's reads leaked into the next region.
            readStartPositions = []
            totalWeight = 0.0
            uniqueReadCount = 0
            reads = []
            numStartsInRegion = 0

        if pos not in readStartPositions:
            numStartsInRegion += 1

        readStartPositions.append(pos)
        weight = 1.0/alignedread.opt('NH')
        totalWeight += weight
        if weight == 1.0:
            uniqueReadCount += 1

        reads.append({"start": pos, "sense": getReadSense(alignedread), "weight": weight})
        previousHit = pos

    return allregions, outregions


def _evaluateSampleRegion(regionFinder, sampleBAM, controlBAM, chromosome, useMulti, readStartPositions,
                          totalWeight, uniqueReadCount, reads, numStartsInRegion, allregions, outregions):
    """Score one finished run of sample reads; append to allregions and, if it passes, to outregions."""
    lastReadPos = readStartPositions[-1]
    lastBasePosition = lastReadPos + regionFinder.readlen - 1
    newRegionIndex = regionFinder.statistics["index"] + 1
    if regionFinder.doDirectionality:
        region = Region.DirectionalRegion(readStartPositions[0], lastBasePosition, chrom=chromosome, index=newRegionIndex, label=regionFinder.regionLabel,
                                          numReads=totalWeight)
    else:
        region = Region.Region(readStartPositions[0], lastBasePosition, chrom=chromosome, index=newRegionIndex, label=regionFinder.regionLabel, numReads=totalWeight)

    if regionFinder.normalize:
        region.numReads /= regionFinder.sampleRDSsize

    allregions.append(int(region.numReads))
    regionLength = lastReadPos - region.start
    if not regionPassesCriteria(regionFinder, region.numReads, numStartsInRegion, regionLength):
        return

    region.foldRatio = getFoldRatio(regionFinder, controlBAM, region.numReads, chromosome, region.start, lastReadPos, useMulti)
    if not region.foldRatio >= regionFinder.minRatio:
        return

    # first pass, with absolute numbers
    peak = findPeak(reads, region.start, regionLength, regionFinder.readlen, doWeight=True, leftPlus=regionFinder.doDirectionality, shift=regionFinder.shiftValue)
    if regionFinder.doTrim:
        try:
            lastReadPos = trimRegion(region, regionFinder, peak, lastReadPos, regionFinder.trimValue, reads, regionFinder.sampleRDSsize)
        except IndexError:
            regionFinder.statistics["badRegionTrim"] += 1
            return

        region.foldRatio = getFoldRatio(regionFinder, controlBAM, region.numReads, chromosome, region.start, lastReadPos, useMulti)

    try:
        bestPos = peak.topPos[0]
        peakScore = peak.smoothArray[bestPos]
        if regionFinder.normalize:
            peakScore /= regionFinder.sampleRDSsize
    except (IndexError, AttributeError, ZeroDivisionError):
        return

    if regionFinder.listPeak:
        region.peakDescription = "%d\t%.1f" % (region.start + bestPos, peakScore)

    if useMulti:
        setMultireadPercentage(region, sampleBAM, regionFinder.sampleRDSsize, totalWeight, uniqueReadCount, chromosome, lastReadPos,
                               regionFinder.normalize, regionFinder.doTrim)

    region.shift = peak.shift
    # check that we still pass threshold
    regionLength = lastReadPos - region.start
    plusRatio = float(peak.numPlus)/peak.numHits
    if regionAndPeakPass(regionFinder, region, regionLength, peakScore, plusRatio):
        try:
            updateRegion(region, regionFinder.doDirectionality, regionFinder.leftPlusRatio, peak.numLeftPlus, peak.numPlus, plusRatio)
            regionFinder.statistics["index"] += 1
            outregions.append(region)
            regionFinder.statistics["total"] += region.numReads
        except RegionDirectionError:
            regionFinder.statistics["failed"] += 1
584
585
def getReadSense(read):
    """Return "-" for a reverse-strand alignment and "+" otherwise."""
    return "-" if read.is_reverse else "+"
593
594
def doNotProcessRead(read, doMulti=False, strand="both", combine5p=False):
    """Return True when an alignment should be skipped.

    read      -- alignment exposing opt('NH') and is_reverse.
    doMulti   -- keep multireads (NH > 1) when True.
    strand    -- "+" keeps only forward reads, "-" only reverse reads; anything else keeps both.
    combine5p -- accepted for call compatibility; not consulted here.
    """
    isMultiread = read.opt('NH') > 1
    if isMultiread and not doMulti:
        return True

    if strand == "+":
        return bool(read.is_reverse)

    if strand == "-":
        return not read.is_reverse

    return False
606
607
def getMaxCoordinate(samfile, chr, doMulti=False):
    """Return the largest read start position on chromosome chr, or 0 if none qualify.

    Multireads (NH > 1) are considered only when doMulti is True.
    """
    highest = 0
    for read in samfile.fetch(chr):
        eligible = read.opt('NH') <= 1 or doMulti
        if eligible and read.pos > highest:
            highest = read.pos

    return highest
618
619
def findBackgroundRegions(regionFinder, sampleBAM, controlBAM, maxCoord, chromosome, useMulti):
    """Scan the control reads on one chromosome for regions enriched over the sample.

    Roles of sample and control are swapped to estimate chance calls: qualifying
    regions are tallied via regionFinder.updateControlStatistics and trimming
    failures in statistics["badRegionTrim"].

    Returns int(read total) of every control run examined, divided by
    controlRDSsize when regionFinder.normalize is set.
    """
    #TODO: this is *almost* the same calculation - there are small yet important differences
    print("calculating background...")
    previousHit = - 1 * regionFinder.maxSpacing
    currentHitList = [-1]
    currentTotalWeight = 0.0
    currentReadList = []
    backregions = []
    numStarts = 0
    for alignedread in controlBAM.fetch(chromosome):
        if doNotProcessRead(alignedread, doMulti=useMulti):
            continue

        pos = alignedread.pos
        if previousRegionIsDone(pos, previousHit, regionFinder.maxSpacing, maxCoord):
            _evaluateBackgroundRegion(regionFinder, sampleBAM, chromosome, useMulti, currentHitList,
                                      currentTotalWeight, currentReadList, numStarts, backregions)
            # Reset unconditionally.  The former inline code used `continue` on trim or
            # peak-lookup failure, which skipped this reset, the append of the current
            # read and the update of previousHit below.
            currentHitList = []
            currentTotalWeight = 0.0
            currentReadList = []
            numStarts = 0

        if pos not in currentHitList:
            numStarts += 1

        currentHitList.append(pos)
        weight = 1.0/alignedread.opt('NH')
        currentTotalWeight += weight
        currentReadList.append({"start": pos, "sense": getReadSense(alignedread), "weight": weight})
        previousHit = pos

    return backregions


def _countMockReads(regionFinder, sampleBAM, chromosome, start, stop, useMulti):
    """Return 1 + the sample reads counted between start and stop, divided by sampleRDSsize when normalizing."""
    # the +1 keeps the fold-ratio division defined for empty sample regions
    numMock = 1. + countReadsInRegion(sampleBAM, chromosome, start, stop, countMulti=useMulti)
    if regionFinder.normalize:
        numMock /= regionFinder.sampleRDSsize

    return numMock


def _evaluateBackgroundRegion(regionFinder, sampleBAM, chromosome, useMulti, hitList, totalWeight,
                              readList, numStarts, backregions):
    """Score one finished run of control reads and record it in the background statistics."""
    lastReadPos = hitList[-1]
    lastBasePosition = lastReadPos + regionFinder.readlen - 1
    region = Region.Region(hitList[0], lastBasePosition, chrom=chromosome, label=regionFinder.regionLabel, numReads=totalWeight)
    if regionFinder.normalize:
        region.numReads /= regionFinder.controlRDSsize

    backregions.append(int(region.numReads))
    # The region used to be rebuilt here from the raw weight, discarding the
    # normalization above while numMock and peakScore stayed normalized.
    regionLength = lastReadPos - region.start
    if not regionPassesCriteria(regionFinder, region.numReads, numStarts, regionLength):
        return

    foldRatio = region.numReads / _countMockReads(regionFinder, sampleBAM, chromosome, region.start, lastReadPos, useMulti)
    if not foldRatio >= regionFinder.minRatio:
        return

    # first pass, with absolute numbers
    peak = findPeak(readList, region.start, lastReadPos - region.start, regionFinder.readlen, doWeight=True,
                    leftPlus=regionFinder.doDirectionality, shift=regionFinder.shiftValue)

    if regionFinder.doTrim:
        # NOTE(review): the sample path trims with regionFinder.trimValue; the literal
        # 20. is kept unchanged here - confirm it is intended.
        try:
            lastReadPos = trimRegion(region, regionFinder, peak, lastReadPos, 20., readList, regionFinder.controlRDSsize)
        except IndexError:
            regionFinder.statistics["badRegionTrim"] += 1
            return

        foldRatio = region.numReads / _countMockReads(regionFinder, sampleBAM, chromosome, region.start, lastReadPos, useMulti)

    try:
        bestPos = peak.topPos[0]
        peakScore = peak.smoothArray[bestPos]
    except IndexError:
        return

    # normalize to RPM
    if regionFinder.normalize:
        peakScore /= regionFinder.controlRDSsize

    # check that we still pass threshold (foldRatio occupies the numStarts slot here, as before)
    regionLength = lastReadPos - region.start
    if regionPassesCriteria(regionFinder, region.numReads, foldRatio, regionLength):
        regionFinder.updateControlStatistics(peak, region.numReads, peakScore)
702
703
704 def learnShift(regionFinder, sampleBAM, maxCoord, chromosome, logfilename, outfilename,
705                outfile, useMulti, controlBAM, combine5p):
706
707     print "learning shift.... will need at least 30 training sites"
708     stringency = regionFinder.stringency
709     previousHit = -1 * regionFinder.maxSpacing
710     positionList = [-1]
711     totalWeight = 0.0
712     readList = []
713     shiftDict = {}
714     count = 0
715     numStarts = 0
716     for alignedread in sampleBAM.fetch(chromosome):
717         if doNotProcessRead(alignedread, doMulti=useMulti, strand=regionFinder.stranded, combine5p=combine5p):
718             continue
719
720         pos = alignedread.pos
721         if previousRegionIsDone(pos, previousHit, regionFinder.maxSpacing, maxCoord):
722             if regionFinder.normalize:
723                 totalWeight /= regionFinder.sampleRDSsize
724
725             regionStart = positionList[0]
726             regionStop = positionList[-1]
727             regionLength = regionStop - regionStart
728             if regionPassesCriteria(regionFinder, totalWeight, numStarts, regionLength, stringency=stringency):
729                 foldRatio = getFoldRatio(regionFinder, controlBAM, totalWeight, chromosome, regionStart, regionStop, useMulti)
730                 if foldRatio >= regionFinder.minRatio:
731                     updateShiftDict(shiftDict, readList, regionStart, regionLength, regionFinder.readlen)
732                     count += 1
733
734             positionList = []
735             totalWeight = 0.0
736             readList = []
737
738         if pos not in positionList:
739             numStarts += 1
740
741         positionList.append(pos)
742         weight = 1.0/alignedread.opt('NH')
743         totalWeight += weight
744         readList.append({"start": pos, "sense": getReadSense(alignedread), "weight": weight})
745         previousHit = pos
746
747     outline = "#learn: stringency=%.2f min_signal=%2.f min_ratio=%.2f min_region_size=%d\n#number of training examples: %d" % (stringency,
748                                                                                                                                stringency * regionFinder.minHits,
749                                                                                                                                stringency * regionFinder.minRatio,
750                                                                                                                                stringency * regionFinder.readlen,
751                                                                                                                                count)
752
753     print outline
754     writeLog(logfilename, versionString, outfilename + outline)
755     shiftValue = getShiftValue(shiftDict, count, logfilename, outfilename)
756     outline = "#picked shiftValue to be %d" % shiftValue
757     print outline
758     print >> outfile, outline
759     writeLog(logfilename, versionString, outfilename + outline)
760
761     return shiftValue, shiftDict
762
763
def previousRegionIsDone(pos, previousHit, maxSpacing, maxCoord):
    """ Decide whether the read at pos closes the region currently being built.

    True when pos lies more than maxSpacing away from previousHit (the last
    read added to the region) or when pos equals maxCoord.
    """
    gap = abs(pos - previousHit)
    if gap > maxSpacing:
        return True

    return pos == maxCoord
766
767
def regionPassesCriteria(regionFinder, sumAll, numStarts, regionLength, stringency=1):
    """ Test a candidate region against the finder thresholds scaled by stringency.

    Passes when all of the following hold:
        sumAll       >= stringency * regionFinder.minHits
        numStarts    >= stringency * regionFinder.minRatio
        regionLength >  stringency * regionFinder.readlen   (strictly longer)
    Some callers in this file pass a fold ratio in the numStarts slot.
    """
    readFloor = stringency * regionFinder.minHits
    startFloor = stringency * regionFinder.minRatio
    lengthFloor = stringency * regionFinder.readlen

    if not sumAll >= readFloor:
        return False

    if not numStarts >= startFloor:
        return False

    return regionLength > lengthFloor
774
775
def trimRegion(region, regionFinder, peak, regionStop, trimValue, currentReadList, totalReadCount):
    """ Shrink region to the span where the smoothed signal reaches trimValue * peak score.

    Edge offsets are found with findStartEdgePosition/findStopEdgePosition
    relative to the untrimmed region.start. The peak is then recomputed over the
    trimmed span with findPeak(). peak's numPlus, numLeftPlus, topPos and
    smoothArray are overwritten with the recomputed values. region.start,
    region.numReads (divided by totalReadCount when regionFinder.normalize) and
    region.stop are updated in place.

    Returns the new last read position (the trimmed stop before the read
    length is added). Any IndexError from the edge scans propagates to the
    caller.

    NOTE(review): with normalize set, the threshold is built from the peak
    score divided by totalReadCount, while peakEdgeLocated compares raw
    smoothArray values against it -- confirm this mix is intended.
    """
    summitScore = peak.smoothArray[peak.topPos[0]]
    if regionFinder.normalize:
        summitScore /= totalReadCount

    threshold = trimValue * summitScore
    leftOffset = findStartEdgePosition(peak, threshold)
    rightOffset = findStopEdgePosition(peak, regionStop - region.start - 1, threshold)

    # both offsets are relative to the untrimmed start, so the new stop must
    # be computed before region.start moves
    newStop = region.start + rightOffset
    region.start = region.start + leftOffset

    trimmed = findPeak(currentReadList, region.start, newStop - region.start, regionFinder.readlen,
                       doWeight=True, leftPlus=regionFinder.doDirectionality, shift=peak.shift)

    for attribute in ("numPlus", "numLeftPlus", "topPos", "smoothArray"):
        setattr(peak, attribute, getattr(trimmed, attribute))

    trimmedReads = trimmed.numHits
    if regionFinder.normalize:
        trimmedReads /= totalReadCount
    region.numReads = trimmedReads

    region.stop = newStop + regionFinder.readlen - 1

    return newStop
805
806
def findStartEdgePosition(peak, minSignalThresh):
    """ Scan smoothArray rightward from index 0 and return the first edge index.

    An index is an edge when peakEdgeLocated() accepts it: the smoothed signal
    there is at least minSignalThresh, or it is the summit peak.topPos[0].
    Assuming the summit index is non-negative, the scan stops there at the latest.
    """
    position = 0
    while True:
        if peakEdgeLocated(peak, position, minSignalThresh):
            return position
        position += 1
813
814
def findStopEdgePosition(peak, stop, minSignalThresh):
    """ Scan smoothArray leftward from index stop and return the first edge index.

    An index is an edge when peakEdgeLocated() accepts it: signal at least
    minSignalThresh, or the summit peak.topPos[0].

    NOTE(review): if stop starts left of the summit, the scan walks into
    negative indices until indexing fails; trimRegion's caller in this file
    catches IndexError for that case.
    """
    position = stop
    while True:
        if peakEdgeLocated(peak, position, minSignalThresh):
            return position
        position -= 1
820
821
def peakEdgeLocated(peak, position, minSignalThresh):
    """ Return True when position is a trimming boundary for peak.

    A boundary is any index whose smoothed signal reaches minSignalThresh, or
    the summit peak.topPos[0] itself. The signal is checked first; the summit
    is only consulted when the signal test fails.
    """
    if peak.smoothArray[position] >= minSignalThresh:
        return True

    return position == peak.topPos[0]
824
825
def getFoldRatio(regionFinder, controlBAM, sumAll, chromosome, regionStart, regionStop, useMulti):
    """ Fold ratio calculated is total read weight over control

    Without a control (regionFinder.doControl false) the ratio is simply
    regionFinder.minRatio, so the caller's minimum-ratio test passes. With a
    control, the denominator is 1 + the control read weight over
    [regionStart, regionStop] (divided by regionFinder.controlRDSsize when
    normalizing). The added 1 keeps the denominator non-zero.
    """
    #TODO: this needs to be generalized as there is a point at which we want to use the sampleRDS instead of controlRDS
    if not regionFinder.doControl:
        return regionFinder.minRatio

    controlWeight = 1. + countReadsInRegion(controlBAM, chromosome, regionStart, regionStop, countMulti=useMulti)
    if regionFinder.normalize:
        controlWeight /= regionFinder.controlRDSsize

    return sumAll / controlWeight
840
841
def countReadsInRegion(bamfile, chr, start, end, uniqs=True, countMulti=False, countSplices=False):
    """ Sum the weights of reads fetched from bamfile over chr:start-end.

    A read with NH > 1 (multiread) contributes 1/NH when countMulti is set.
    A read with NH <= 1 contributes 1.0 when uniqs is set. Spliced reads
    (isSpliceEntry on the CIGAR) are only counted when countSplices is set.
    Returns the total as a float.
    """
    total = 0.0
    for read in bamfile.fetch(chr, start, end):
        hits = read.opt('NH')
        if hits > 1:
            if not countMulti:
                continue
            increment = 1.0/hits
        else:
            if not uniqs:
                continue
            increment = 1.0

        if not isSpliceEntry(read.cigar) or countSplices:
            total += increment

    return total
860
def updateShiftDict(shiftDict, readList, regionStart, regionLength, readlen):
    """ Record one vote in shiftDict for the shift findPeak picks on this region.

    findPeak is run with doWeight=True and shift="auto". The vote count for the
    resulting peak.shift is incremented, starting from 0 for unseen shifts.
    """
    learnedPeak = findPeak(readList, regionStart, regionLength, readlen, doWeight=True, shift="auto")
    votedShift = learnedPeak.shift
    shiftDict[votedShift] = shiftDict.get(votedShift, 0) + 1
867
868
869 def getShiftValue(shiftDict, count, logfilename, outfilename):
870     if count < 30:
871         outline = "#too few training examples to pick a shiftValue - defaulting to 0\n#consider picking a lower minimum or threshold"
872         print outline
873         writeLog(logfilename, versionString, outfilename + outline)
874         shiftValue = 0
875     else:
876         shiftValue = getBestShiftInDict(shiftDict)
877         print shiftDict
878
879     return shiftValue
880
881
def getRegion(regionStart, regionStop, factor, index, chromosome, sumAll, foldRatio, multiP,
              peakDescription, shift, doDirectionality, leftPlusRatio, numLeft,
              numPlus, plusRatio):
    """ Build the output region, optionally enforcing the directionality test.

    Without doDirectionality a plain Region.Region is returned. With it, the
    fraction of plus-strand reads left of the peak (numLeft / numPlus) must be
    strictly greater than leftPlusRatio. A Region.DirectionalRegion is then
    returned carrying plusP = 100 * plusRatio and leftP = 100 * numLeft / numPlus.

    Raises:
        RegionDirectionError: doDirectionality is set and the left-plus
            fraction does not exceed leftPlusRatio.
    Division by numPlus is unguarded, so numPlus == 0 fails with
    ZeroDivisionError for plain Python numbers.

    NOTE(review): this positional Region.Region call differs from the keyword
    form (chrom=, label=, numReads=) used elsewhere in this file -- confirm it
    still matches the Region constructor.
    """
    if not doDirectionality:
        # we have a region, but didn't check for directionality
        return Region.Region(regionStart, regionStop, factor, index, chromosome,
                             sumAll, foldRatio, multiP, peakDescription, shift)

    # float() keeps the fraction from truncating under Python 2 when both
    # counts are ints (1 / 2 == 0 would reject every region)
    leftFraction = float(numLeft) / numPlus
    if not leftPlusRatio < leftFraction:
        raise RegionDirectionError

    plusP = plusRatio * 100.
    leftP = 100. * numLeft / numPlus
    # we have a region that passes all criteria
    return Region.DirectionalRegion(regionStart, regionStop,
                                    factor, index, chromosome, sumAll,
                                    foldRatio, multiP, plusP, leftP,
                                    peakDescription, shift)
904
905
def setMultireadPercentage(region, sampleBAM, hitRDSsize, currentTotalWeight, currentUniqueCount, chromosome, lastReadPos, normalize, doTrim):
    """ Set region.multiP to the share (in percent) of region.numReads from multireads.

    With doTrim, the multiread weight is recounted from sampleBAM over
    chromosome:region.start-lastReadPos as the sum of 1/NH over reads with
    NH > 1. Otherwise it is currentTotalWeight - currentUniqueCount. The
    weight is divided by hitRDSsize when normalize is set. The result is capped
    at 100. If region.numReads is zero, region is left unchanged.
    """
    if not doTrim:
        multiWeight = currentTotalWeight - currentUniqueCount
    else:
        trimmedReads = sampleBAM.fetch(chromosome, region.start, lastReadPos)
        multiWeight = sum((1.0/read.opt('NH') for read in trimmedReads if read.opt('NH') > 1), 0.0)

    # normalize to RPM
    if normalize:
        multiWeight /= hitRDSsize

    try:
        percentMulti = 100. * (multiWeight / region.numReads)
    except ZeroDivisionError:
        return
    else:
        region.multiP = min(percentMulti, 100.)
925
926
def regionAndPeakPass(regionFinder, region, regionLength, peakScore, plusRatio):
    """ Return True when region passes the basic criteria and the peak filters.

    regionPassesCriteria is applied with region.foldRatio in the read-start
    slot. In addition, peakScore must reach regionFinder.minPeak, and plusRatio
    must lie within [regionFinder.minPlusRatio, regionFinder.maxPlusRatio].
    """
    if not regionPassesCriteria(regionFinder, region.numReads, region.foldRatio, regionLength):
        return False

    #TODO: here is where the test dataset is failing
    peakIsHighEnough = peakScore >= regionFinder.minPeak
    return bool(peakIsHighEnough and regionFinder.minPlusRatio <= plusRatio <= regionFinder.maxPlusRatio)
935
936
def updateRegion(region, doDirectionality, leftPlusRatio, numLeft, numPlus, plusRatio):
    """ Attach strand-directionality percentages to region, or reject it.

    Without doDirectionality, region is left untouched. Otherwise the fraction
    of plus-strand reads left of the peak (numLeft / numPlus) must be strictly
    greater than leftPlusRatio. In that case region.plusP = 100 * plusRatio
    and region.leftP = 100 * numLeft / numPlus.

    Raises:
        RegionDirectionError: the left-plus fraction does not exceed leftPlusRatio.
    Division by numPlus is unguarded, so numPlus == 0 fails with
    ZeroDivisionError for plain Python numbers.
    """
    if not doDirectionality:
        return

    # float() keeps the fraction from truncating under Python 2 when both
    # counts are ints (1 / 2 == 0 would reject every region)
    leftFraction = float(numLeft) / numPlus
    if not leftPlusRatio < leftFraction:
        raise RegionDirectionError

    region.plusP = plusRatio * 100.
    region.leftP = 100. * numLeft / numPlus
945
946
947 def writeChromosomeResults(regionFinder, outregions, outfile, shiftDict,
948                            allregions, header, backregions=[], pValueType="none"):
949
950     print regionFinder.statistics["mIndex"], regionFinder.statistics["mTotal"]
951     if regionFinder.doPvalue:
952         if pValueType == "self":
953             poissonmean = calculatePoissonMean(allregions)
954         else:
955             poissonmean = calculatePoissonMean(backregions)
956
957     print header
958     for region in outregions:
959         if regionFinder.shiftValue == "auto" and regionFinder.reportshift:
960             try:
961                 shiftDict[region.shift] += 1
962             except KeyError:
963                 shiftDict[region.shift] = 1
964
965         outline = getRegionString(region, regionFinder.reportshift)
966
967         # iterative poisson from http://stackoverflow.com/questions/280797?sort=newest
968         if regionFinder.doPvalue:
969             pValue = calculatePValue(int(region.numReads), poissonmean)
970             outline += "\t%1.2g" % pValue
971
972         print outline
973         print >> outfile, outline
974
975
976 def calculatePoissonMean(dataList):
977     dataList.sort()
978     listSize = float(len(dataList))
979     try:
980         poissonmean = sum(dataList) / listSize
981     except ZeroDivisionError:
982         poissonmean = 0
983
984     print "Poisson n=%d, p=%f" % (listSize, poissonmean)
985
986     return poissonmean
987
988
def calculatePValue(sum, poissonmean):
    """ Return the Poisson probability of observing exactly `sum` reads.

    Computes exp(-mean) * mean**k / k! with k = sum. The work is done in log
    space (math.lgamma for log k!), so large means no longer underflow:
    exp(-mean) alone is 0.0 for means above ~745, which made the old
    iterative product return 0.0 for every region.

    For k <= 0 the result is exp(-mean), matching the empty product of the
    iterative form. A non-positive mean with k > 0 gives 0.0.

    NOTE(review): this is the point probability P(X == k), not the upper tail
    P(X >= k) usually reported as a p-value -- confirm that is intended.
    """
    k = int(sum)
    if k <= 0:
        return math.exp(-poissonmean)

    if poissonmean <= 0:
        return 0.0

    logProbability = k * math.log(poissonmean) - poissonmean - math.lgamma(k + 1)
    return math.exp(logProbability)
996
997
def getRegionString(region, reportShift):
    """ Format region for output, including its shift when reportShift is truthy.

    Uses region.printRegionWithShift() or region.printRegion() accordingly.
    """
    formatter = region.printRegionWithShift if reportShift else region.printRegion
    return formatter()
1005
1006
def getBestShiftInDict(shiftDict):
    """ Return the shift with the most votes in shiftDict.

    Ties go to the first such key in dictionary iteration order. An empty
    dictionary raises ValueError.
    """
    return max(shiftDict, key=shiftDict.get)
1009
1010
if __name__ == "__main__":
    # script entry point: pass the full command line (including argv[0]) to main()
    main(sys.argv)