rewrite of findall.py and MakeRdsFromBam to fix bugs resulting from poor initial...
[erange.git] / findall.py
1 """
2     usage: python $ERANGEPATH/findall.py label samplerdsfile regionoutfile
3            [--control controlrdsfile] [--minimum minHits] [--ratio minRatio]
4            [--spacing maxSpacing] [--listPeak] [--shift #bp | learn] [--learnFold num]
5            [--noshift] [--autoshift] [--reportshift] [--nomulti] [--minPlus fraction]
6            [--maxPlus fraction] [--leftPlus fraction] [--minPeak RPM] [--raw]
7            [--revbackground] [--pvalue self|back|none] [--nodirectionality]
8            [--strandfilter plus/minus] [--trimvalue percent] [--notrim]
9            [--cache pages] [--log altlogfile] [--flag aflag] [--append] [--RNA]
10
11            where values in brackets are optional and label is an arbitrary string.
12
13            Use -ratio (default 4 fold) to set the minimum fold enrichment
14            over the control, -minimum (default 4) is the minimum number of reads
           (RPM) within the region, and -spacing (default 50 bp; readlen with -RNA) to set the maximum
16            distance between reads in the region. -listPeak lists the peak of the
           region. Peaks must be higher than -minPeak (default 0.5 RPM).
18            Pvalues are calculated from the sample (change with -pvalue),
19            unless the -revbackground flag and a control RDS file are provided.
20
21            By default, all numbers and parameters are on a reads per
22            million (RPM) basis. -raw will treat all settings, ratios and reported
23            numbers as raw counts rather than RPM. Use -notrim to turn off region
24            trimming and -trimvalue to control trimming (default 10% of peak signal)
25
26            The peak finder uses minimal directionality information that can
27            be turned off with -nodirectionality ; the fraction of + strand reads
28            required to be to the left of the peak (default 0.3) can be set with
29            -leftPlus ; -minPlus and -maxPlus change the minimum/maximum fraction
           of plus reads in a region (defaults 0.25 and 0.75, respectively).
31
32            Use -shift to shift reads either by half the expected
           fragment length (default 0 bp) or '-shift learn' to learn the shift
           based on the first chromosome. Alternatively, use -autoshift to
           calculate a separate shift value for each region; per-region shifts
           can be reported using -reportshift. -strandfilter should only be used
           when explicitly calling unshifted stranded peaks from non-ChIP-seq
38            data such as directional RNA-seq. regionoutfile is written over by
39            default unless given the -append flag.
40 """
41
42 try:
43     import psyco
44     psyco.full()
45 except:
46     pass
47
48 import sys
49 import math
50 import string
51 import optparse
52 import operator
53 from commoncode import writeLog, findPeak, getBestShiftForRegion, getConfigParser, getConfigOption, getConfigIntOption, getConfigFloatOption, getConfigBoolOption
54 import ReadDataset
55 import Region
56
57
# Version banner for this script. It is echoed to stdout as soon as the module
# is loaded, and is also written to the log and to the output file header.
versionString = "findall: version 3.2.1"
print(versionString)
60
class RegionDirectionError(Exception):
    """Signal that a candidate region was rejected on strand directionality.

    findPeakRegions catches it around updateRegion and counts the region in
    statistics["failed"]; presumably updateRegion raises it when too few plus
    reads lie to the left of the peak (verify against updateRegion).
    """
63             
64
class RegionFinder():
    """Settings and running statistics shared by the findall region scanners.

    Holds the thresholds (minHits, minRatio, minPeak, plus-strand ratios,
    spacing, trimming, shift) and the output options used when calling regions.
    self.statistics is updated by the scanners in this module:
        index, total    -- number and summed reads of accepted sample regions
        mIndex, mTotal  -- number and summed reads of accepted control regions
                           (see updateControlStatistics)
        failed          -- regions rejected by the directionality test
        badRegionTrim   -- regions dropped because trimming raised IndexError
    self.readlen is not set by __init__; findall() assigns it from the sample.
    """

    def __init__(self, label, minRatio=4.0, minPeak=0.5, minPlusRatio=0.25, maxPlusRatio=0.75, leftPlusRatio=0.3, strandfilter="",
                 minHits=4.0, trimValue=0.1, doTrim=True, doDirectionality=True, shiftValue=0, maxSpacing=50, withFlag="",
                 normalize=True, listPeak=False, reportshift=False, stringency=1.0):
        """Store the peak-calling settings.

        strandfilter "plus" or "minus" restricts the analysis to one strand and
        overrides minPlusRatio/maxPlusRatio accordingly. minPeak is lowered to
        minRatio when minRatio is the smaller of the two, and stringency is
        raised to at least 1.0. trimValue is a fraction (reported as percent).
        """

        self.statistics = {"index": 0,
                           "total": 0,
                           "mIndex": 0,
                           "mTotal": 0,
                           "failed": 0,
                           "badRegionTrim": 0
        }

        self.regionLabel = label
        self.rnaSettings = False
        self.controlRDSsize = 1
        self.sampleRDSsize = 1
        self.minRatio = minRatio
        self.minPeak = minPeak
        self.leftPlusRatio = leftPlusRatio
        self.stranded = "both"
        if strandfilter == "plus":
            self.stranded = "+"
            minPlusRatio = 0.9
            maxPlusRatio = 1.0
        elif strandfilter == "minus":
            self.stranded = "-"
            minPlusRatio = 0.0
            maxPlusRatio = 0.1

        if minRatio < minPeak:
            self.minPeak = minRatio

        self.minPlusRatio = minPlusRatio
        self.maxPlusRatio = maxPlusRatio
        self.strandfilter = strandfilter
        self.minHits = minHits
        self.trimValue = trimValue
        self.doTrim = doTrim
        self.doDirectionality = doDirectionality

        if self.doTrim:
            self.trimString = "%2.1f%%" % (100. * self.trimValue)
        else:
            self.trimString = "none"

        self.shiftValue = shiftValue
        self.maxSpacing = maxSpacing
        self.withFlag = withFlag
        self.normalize = normalize
        self.listPeak = listPeak
        self.reportshift = reportshift
        self.stringency = max(stringency, 1.0)


    def useRNASettings(self, readlen):
        """Switch to RNA defaults: no shift, no trimming, no directionality, spacing = readlen."""
        self.rnaSettings = True
        self.shiftValue = 0
        self.doTrim = False
        self.doDirectionality = False
        self.maxSpacing = readlen


    def getHeader(self, doPvalue):
        """Return the tab-separated column header line for the region output file.

        Optional columns are added for directionality, listPeak, reportshift and,
        when doPvalue is true, the p-value.
        """
        if self.normalize:
            countType = "RPM"
        else:
            countType = "COUNT"

        headerFields = ["#regionID\tchrom\tstart\tstop", countType, "fold\tmulti%"]

        if self.doDirectionality:
            headerFields.append("plus%\tleftPlus%")

        if self.listPeak:
            headerFields.append("peakPos\tpeakHeight")

        if self.reportshift:
            headerFields.append("readShift")

        if doPvalue:
            headerFields.append("pValue")

        return "\t".join(headerFields)


    def printSettings(self, doRevBackground, ptype, doControl, useMulti, doCache, pValueType):
        """Print the status messages followed by the option summary to stdout."""
        print("")
        self.printStatusMessages(doRevBackground, ptype, doControl, useMulti)
        self.printOptionsSummary(useMulti, doCache, pValueType)


    def printStatusMessages(self, doRevBackground, ptype, doControl, useMulti):
        """Print one line per non-default setting, and warn about unusable p-value types.

        ptype is matched case-insensitively, as the usage text documents the
        lowercase spellings "self", "back" and "none".
        """
        if self.shiftValue == "learn":
            print("Will try to learn shift")

        if self.normalize:
            print("Normalizing to RPM")

        if doRevBackground:
            print("Swapping IP and background to calculate FDR")

        requestedPType = (ptype or "").upper()
        if requestedPType != "":
            if requestedPType in ["NONE", "SELF"]:
                pass
            elif requestedPType == "BACK":
                if not (doControl and doRevBackground):
                    print("must have a control dataset and -revbackground for pValue type 'back'")
            else:
                print("could not use pValue type : %s" % ptype)

        if self.withFlag != "":
            print("restrict to flag = %s" % self.withFlag)

        if not useMulti:
            print("using unique reads only")

        if self.rnaSettings:
            print("using settings appropriate for RNA: -nodirectionality -notrim -noshift")

        if self.strandfilter == "plus":
            print("only analyzing reads on the plus strand")
        elif self.strandfilter == "minus":
            print("only analyzing reads on the minus strand")


    def _formatShift(self):
        """Return shiftValue formatted with %d, or as-is for non-numbers such as "learn"/"auto"."""
        try:
            return "%d" % self.shiftValue
        except TypeError:
            return "%s" % self.shiftValue


    def printOptionsSummary(self, useMulti, doCache, pValueType):
        """Print the effective thresholds and options to stdout."""

        print("\nenforceDirectionality=%s listPeak=%s nomulti=%s cache=%s " % (self.doDirectionality, self.listPeak, not useMulti, doCache))
        print("spacing<%d minimum>%.1f ratio>%.1f minPeak=%.1f\ttrimmed=%s\tstrand=%s" % (self.maxSpacing, self.minHits, self.minRatio, self.minPeak, self.trimString, self.stranded))
        print("minPlus=%.2f maxPlus=%.2f leftPlus=%.2f shift=%s pvalue=%s" % (self.minPlusRatio, self.maxPlusRatio, self.leftPlusRatio, self._formatShift(), pValueType))


    def getAnalysisDescription(self, hitfile, useMulti, doCache, pValueType, controlfile, doControl):
        """Return the '#'-prefixed description block written at the top of the output file."""

        description = ["#ERANGE %s" % versionString]
        if doControl:
            description.append("#enriched sample:\t%s (%.1f M reads)\n#control sample:\t%s (%.1f M reads)" % (hitfile, self.sampleRDSsize, controlfile, self.controlRDSsize))
        else:
            description.append("#enriched sample:\t%s (%.1f M reads)\n#control sample: none" % (hitfile, self.sampleRDSsize))

        if self.withFlag != "":
            description.append("#restrict to Flag = %s" % self.withFlag)

        description.append("#enforceDirectionality=%s listPeak=%s nomulti=%s cache=%s" % (self.doDirectionality, self.listPeak, not useMulti, doCache))
        description.append("#spacing<%d minimum>%.1f ratio>%.1f minPeak=%.1f trimmed=%s strand=%s" % (self.maxSpacing, self.minHits, self.minRatio, self.minPeak, self.trimString, self.stranded))
        description.append("#minPlus=%.2f maxPlus=%.2f leftPlus=%.2f shift=%s pvalue=%s" % (self.minPlusRatio, self.maxPlusRatio, self.leftPlusRatio, self._formatShift(), pValueType))

        return "\n".join(description)


    def updateControlStatistics(self, peak, sumAll, peakScore):
        """Tally one control (background) region that already passed the read/fold thresholds.

        peak is the findPeak result for the region; numPlus and numHits are read,
        and numLeftPlus when directionality is enforced. sumAll is the region's
        read count and peakScore its peak height. Regions whose peak height and
        plus-strand fraction pass are added to mIndex/mTotal, or counted in
        failed when they do not clear the left-plus directionality test.
        """

        plusRatio = float(peak.numPlus) / peak.numHits
        if peakScore < self.minPeak or not (self.minPlusRatio <= plusRatio <= self.maxPlusRatio):
            return

        if self.doDirectionality:
            # Same left-plus statistic findPeakRegions hands to updateRegion.
            # Float division: with integer counts Python 2 would truncate the
            # ratio to 0 or 1. A region with no plus reads cannot pass.
            if peak.numPlus > 0:
                leftPlusFraction = float(peak.numLeftPlus) / peak.numPlus
            else:
                leftPlusFraction = 0.0

            if not (self.leftPlusRatio < leftPlusFraction):
                self.statistics["failed"] += 1
                return

        self.statistics["mIndex"] += 1
        self.statistics["mTotal"] += sumAll
238
239
def usage():
    """Print the module docstring, which doubles as the command-line help text."""
    helpText = __doc__
    print(helpText)
242
243
def main(argv=None):
    """Command-line entry point for findall.

    Parses argv (sys.argv when argv is empty or None; argv[0] is skipped),
    builds a RegionFinder from the options and runs findall(). The three
    required positional arguments are the region label, the sample RDS file
    and the output file; with fewer, the usage text is printed and the process
    exits with status 2.
    """
    if not argv:
        argv = sys.argv

    parser = makeParser()
    options, args = parser.parse_args(argv[1:])
    if len(args) < 3:
        usage()
        sys.exit(2)

    factor, hitfile, outfilename = args[:3]

    # Shift precedence, lowest to highest: default 0, --autoshift,
    # --shift <int|learn>, --noshift.
    if options.noshift:
        shiftValue = 0
    else:
        if options.autoshift:
            shiftValue = "auto"
        else:
            shiftValue = 0

        requestedShift = options.shift
        if requestedShift is not None:
            try:
                shiftValue = int(requestedShift)
            except ValueError:
                # anything that is neither an integer nor "learn" is ignored
                if requestedShift == "learn":
                    shiftValue = "learn"

    if options.doAppend:
        outputMode = "a"
    else:
        outputMode = "w"

    regionFinder = RegionFinder(factor,
                                minRatio=options.minRatio,
                                minPeak=options.minPeak,
                                minPlusRatio=options.minPlusRatio,
                                maxPlusRatio=options.maxPlusRatio,
                                leftPlusRatio=options.leftPlusRatio,
                                strandfilter=options.strandfilter,
                                minHits=options.minHits,
                                trimValue=options.trimValue,
                                doTrim=options.doTrim,
                                doDirectionality=options.doDirectionality,
                                shiftValue=shiftValue,
                                maxSpacing=options.maxSpacing,
                                withFlag=options.withFlag,
                                normalize=options.normalize,
                                listPeak=options.listPeak,
                                reportshift=options.reportshift,
                                stringency=options.stringency)

    findall(regionFinder, hitfile, outfilename, options.logfilename, outputMode, options.rnaSettings,
            options.cachePages, options.ptype, options.controlfile, options.doRevBackground,
            options.useMulti, options.combine5p)
289
290
def makeParser():
    """Build the optparse parser for findall's command line.

    Defaults are read from the "findall" section of the configuration returned
    by getConfigParser (falling back to the literals below) and installed with
    set_defaults, so explicit command-line flags override them. The option
    dest names must match the keys passed to set_defaults and the attributes
    read by main().
    """
    usage = __doc__

    parser = optparse.OptionParser(usage=usage)
    parser.add_option("--control", dest="controlfile")
    parser.add_option("--minimum", type="float", dest="minHits")
    parser.add_option("--ratio", type="float", dest="minRatio")
    parser.add_option("--spacing", type="int", dest="maxSpacing")
    parser.add_option("--listPeak", action="store_true", dest="listPeak")
    parser.add_option("--shift", dest="shift")
    parser.add_option("--learnFold", type="float", dest="stringency")
    # dest must be "noshift": that is the key given to set_defaults below and
    # the attribute main() reads (a different spelling made the flag a no-op).
    parser.add_option("--noshift", action="store_true", dest="noshift")
    parser.add_option("--autoshift", action="store_true", dest="autoshift")
    parser.add_option("--reportshift", action="store_true", dest="reportshift")
    parser.add_option("--nomulti", action="store_false", dest="useMulti")
    parser.add_option("--minPlus", type="float", dest="minPlusRatio")
    parser.add_option("--maxPlus", type="float", dest="maxPlusRatio")
    parser.add_option("--leftPlus", type="float", dest="leftPlusRatio")
    parser.add_option("--minPeak", type="float", dest="minPeak")
    parser.add_option("--raw", action="store_false", dest="normalize")
    parser.add_option("--revbackground", action="store_true", dest="doRevBackground")
    parser.add_option("--pvalue", dest="ptype")
    parser.add_option("--nodirectionality", action="store_false", dest="doDirectionality")
    parser.add_option("--strandfilter", dest="strandfilter")
    parser.add_option("--trimvalue", type="float", dest="trimValue")
    parser.add_option("--notrim", action="store_false", dest="doTrim")
    parser.add_option("--cache", type="int", dest="cachePages")
    parser.add_option("--log", dest="logfilename")
    parser.add_option("--flag", dest="withFlag")
    parser.add_option("--append", action="store_true", dest="doAppend")
    parser.add_option("--RNA", action="store_true", dest="rnaSettings")
    parser.add_option("--combine5p", action="store_true", dest="combine5p")

    configParser = getConfigParser()
    section = "findall"
    minHits = getConfigFloatOption(configParser, section, "minHits", 4.0)
    minRatio = getConfigFloatOption(configParser, section, "minRatio", 4.0)
    maxSpacing = getConfigIntOption(configParser, section, "maxSpacing", 50)
    listPeak = getConfigBoolOption(configParser, section, "listPeak", False)
    shift = getConfigOption(configParser, section, "shift", None)
    stringency = getConfigFloatOption(configParser, section, "stringency", 4.0)
    noshift = getConfigBoolOption(configParser, section, "noshift", False)
    autoshift = getConfigBoolOption(configParser, section, "autoshift", False)
    reportshift = getConfigBoolOption(configParser, section, "reportshift", False)
    minPlusRatio = getConfigFloatOption(configParser, section, "minPlusRatio", 0.25)
    maxPlusRatio = getConfigFloatOption(configParser, section, "maxPlusRatio", 0.75)
    leftPlusRatio = getConfigFloatOption(configParser, section, "leftPlusRatio", 0.3)
    minPeak = getConfigFloatOption(configParser, section, "minPeak", 0.5)
    normalize = getConfigBoolOption(configParser, section, "normalize", True)
    logfilename = getConfigOption(configParser, section, "logfilename", "findall.log")
    withFlag = getConfigOption(configParser, section, "withFlag", "")
    doDirectionality = getConfigBoolOption(configParser, section, "doDirectionality", True)
    trimValue = getConfigFloatOption(configParser, section, "trimValue", 0.1)
    doTrim = getConfigBoolOption(configParser, section, "doTrim", True)
    doAppend = getConfigBoolOption(configParser, section, "doAppend", False)
    rnaSettings = getConfigBoolOption(configParser, section, "rnaSettings", False)
    cachePages = getConfigOption(configParser, section, "cachePages", None)
    ptype = getConfigOption(configParser, section, "ptype", "")
    controlfile = getConfigOption(configParser, section, "controlfile", None)
    doRevBackground = getConfigBoolOption(configParser, section, "doRevBackground", False)
    useMulti = getConfigBoolOption(configParser, section, "useMulti", True)
    strandfilter = getConfigOption(configParser, section, "strandfilter", "")
    combine5p = getConfigBoolOption(configParser, section, "combine5p", False)

    parser.set_defaults(minHits=minHits, minRatio=minRatio, maxSpacing=maxSpacing, listPeak=listPeak, shift=shift,
                        stringency=stringency, noshift=noshift, autoshift=autoshift, reportshift=reportshift,
                        minPlusRatio=minPlusRatio, maxPlusRatio=maxPlusRatio, leftPlusRatio=leftPlusRatio, minPeak=minPeak,
                        normalize=normalize, logfilename=logfilename, withFlag=withFlag, doDirectionality=doDirectionality,
                        trimValue=trimValue, doTrim=doTrim, doAppend=doAppend, rnaSettings=rnaSettings,
                        cachePages=cachePages, ptype=ptype, controlfile=controlfile, doRevBackground=doRevBackground, useMulti=useMulti,
                        strandfilter=strandfilter, combine5p=combine5p)

    return parser
364
365
def findall(regionFinder, hitfile, outfilename, logfilename="findall.log", outputMode="w", rnaSettings=False, cachePages=None,
            ptype="", controlfile=None, doRevBackground=False, useMulti=True, combine5p=False):
    """Call enriched regions on every chromosome of hitfile and write them to outfilename.

    regionFinder supplies the thresholds and collects the statistics. When
    controlfile is given, only chromosomes present in both datasets are
    scanned and the control is passed to the fold-ratio calculation; with
    doRevBackground the control is also scanned as a pseudo-sample
    (findBackgroundRegions). rnaSettings applies RegionFinder.useRNASettings,
    and a non-None cachePages enables the dataset cache. The output file is
    opened with outputMode ("w" or "a") and is always closed, even when a scan
    raises. The summary footer is printed, appended to the output file and
    written to logfilename.
    """

    writeLog(logfilename, versionString, " ".join(sys.argv[1:]))
    doCache = cachePages is not None
    controlRDS = None
    doControl = controlfile is not None
    if doControl:
        print("\ncontrol:")
        controlRDS = openRDSFile(controlfile, cachePages=cachePages, doCache=doCache)
        regionFinder.controlRDSsize = len(controlRDS) / 1000000.

    print("\nsample:")
    hitRDS = openRDSFile(hitfile, cachePages=cachePages, doCache=doCache)
    regionFinder.sampleRDSsize = len(hitRDS) / 1000000.
    pValueType = getPValueType(ptype, doControl, doRevBackground)
    doPvalue = pValueType != "none"
    regionFinder.readlen = hitRDS.getReadSize()
    if rnaSettings:
        regionFinder.useRNASettings(regionFinder.readlen)

    regionFinder.printSettings(doRevBackground, ptype, doControl, useMulti, doCache, pValueType)
    outfile = open(outfilename, outputMode)
    try:
        header = writeOutputFileHeader(regionFinder, outfile, hitfile, useMulti, doCache, pValueType, doPvalue, controlfile, doControl)
        shiftDict = {}
        chromosomeList = getChromosomeListToProcess(hitRDS, controlRDS, doControl)
        for chromosome in chromosomeList:
            if regionFinder.shiftValue == "learn":
                learnShift(regionFinder, hitRDS, chromosome, logfilename, outfilename, outfile, useMulti, doControl, controlRDS, combine5p)

            allregions, outregions = findPeakRegions(regionFinder, hitRDS, chromosome, logfilename, outfilename, outfile, useMulti, doControl, controlRDS, combine5p)
            if doRevBackground:
                backregions = findBackgroundRegions(regionFinder, hitRDS, controlRDS, chromosome, useMulti)
                writeChromosomeResults(regionFinder, outregions, outfile, doPvalue, shiftDict, allregions, header, backregions=backregions, pValueType=pValueType)
            else:
                writeNoRevBackgroundResults(regionFinder, outregions, outfile, doPvalue, shiftDict, allregions, header)

        footer = getFooter(regionFinder, shiftDict, doRevBackground)
        print(footer)
        outfile.write(footer + "\n")
    finally:
        # close the output even if a chromosome scan fails part-way
        outfile.close()

    writeLog(logfilename, versionString, outfilename + footer.replace("\n#", " | ")[:-1])
408
409
def getPValueType(ptype, doControl, doRevBackground):
    """Map the --pvalue option onto the p-value mode used by findall.

    ptype is matched case-insensitively against "none", "self" and "back"
    (the usage text documents the lowercase spellings). "back" is honoured only
    when a control dataset is used together with doRevBackground; otherwise it
    falls back to "self". An empty, None or unrecognised ptype selects "back"
    when doRevBackground is set and "self" otherwise.

    Returns one of "none", "self" or "back".
    """
    requested = (ptype or "").upper()
    if requested == "NONE":
        return "none"

    if requested == "SELF":
        return "self"

    if requested == "BACK":
        if doControl and doRevBackground:
            return "back"

        return "self"

    if doRevBackground:
        return "back"

    return "self"
424
425
def openRDSFile(filename, cachePages=None, doCache=False):
    """Open an RDS read dataset verbosely, optionally enlarging its DB cache.

    cachePages is None (leave the cache size alone) or a page count; a string
    count, as a config-file value may arrive, is converted with int(). The
    cache is only changed when the requested size exceeds the dataset's
    default. Returns the opened ReadDataset.
    """
    rds = ReadDataset.ReadDataset(filename, verbose=True, cache=doCache)
    # Explicit None check instead of relying on Python 2 ordering None below ints.
    if cachePages is not None:
        cachePages = int(cachePages)
        if cachePages > rds.getDefaultCacheSize():
            rds.setDBcache(cachePages)

    return rds
432
433
def writeOutputFileHeader(regionFinder, outfile, hitfile, useMulti, doCache, pValueType, doPvalue, controlfile, doControl):
    """Write the analysis description and then the column header line to outfile.

    Returns the column header line (without its trailing newline).
    """
    description = regionFinder.getAnalysisDescription(hitfile, useMulti, doCache, pValueType, controlfile, doControl)
    outfile.write(description + "\n")

    columnHeader = regionFinder.getHeader(doPvalue)
    outfile.write(columnHeader + "\n")
    return columnHeader
440
441
def getChromosomeListToProcess(hitRDS, controlRDS=None, doControl=False):
    """Return the sample's chromosomes to scan, in the sample's own order.

    chrM is always skipped. When doControl is true, only chromosomes that
    controlRDS also reports are kept.
    """
    skipped = "chrM"
    candidates = hitRDS.getChromosomes()
    if not doControl:
        return [name for name in candidates if name != skipped]

    sharedWithControl = set(controlRDS.getChromosomes())
    return [name for name in candidates if name in sharedWithControl and name != skipped]
451
452
def findPeakRegions(regionFinder, hitRDS, chromosome, logfilename, outfilename,
                    outfile, useMulti, doControl, controlRDS, combine5p):
    """Scan one chromosome of the sample for enriched regions.

    Reads from hitRDS.getReadsDict (presumably ordered by start) are grouped
    into candidate regions; a region is closed whenever previousRegionIsDone
    says so (it is given the gap limit regionFinder.maxSpacing and the
    chromosome's max coordinate). Each closed region is evaluated by
    _scorePeakRegion. Reads are always reset after a region is closed, even
    when that region is rejected, so they never leak into the next region.

    logfilename, outfilename and outfile are accepted for interface
    compatibility and are not used here.

    Returns (allregions, outregions): the int-truncated (RPM when normalizing)
    read count of every candidate region, and the Region objects that passed
    all thresholds.
    """

    outregions = []
    allregions = []
    print("chromosome %s" % (chromosome))
    previousHit = - 1 * regionFinder.maxSpacing
    readStartPositions = [-1]
    totalWeight = 0
    uniqueReadCount = 0
    reads = []
    numStarts = 0
    hitDict = hitRDS.getReadsDict(fullChrom=True, chrom=chromosome, flag=regionFinder.withFlag, withWeight=True, doMulti=useMulti, findallOptimize=True,
                                  strand=regionFinder.stranded, combine5p=combine5p)

    maxCoord = hitRDS.getMaxCoordinate(chromosome, doMulti=useMulti)
    for read in hitDict[chromosome]:
        pos = read["start"]
        if previousRegionIsDone(pos, previousHit, regionFinder.maxSpacing, maxCoord):
            trimFailed = _scorePeakRegion(regionFinder, hitRDS, controlRDS, chromosome, useMulti, doControl,
                                          readStartPositions, reads, totalWeight, uniqueReadCount, numStarts,
                                          allregions, outregions)
            if trimFailed:
                regionFinder.statistics["badRegionTrim"] += 1

            readStartPositions = []
            totalWeight = 0
            uniqueReadCount = 0
            reads = []
            numStarts = 0

        if pos not in readStartPositions:
            numStarts += 1

        readStartPositions.append(pos)
        weight = read["weight"]
        totalWeight += weight
        if weight == 1.0:
            uniqueReadCount += 1

        reads.append({"start": pos, "sense": read["sense"], "weight": weight})
        previousHit = pos

    return allregions, outregions


def _scorePeakRegion(regionFinder, hitRDS, controlRDS, chromosome, useMulti, doControl,
                     readStartPositions, reads, totalWeight, uniqueReadCount, numStarts,
                     allregions, outregions):
    """Build, threshold, trim and record one closed candidate region of the sample.

    Appends the region's read count to allregions and, when every test passes,
    appends the Region to outregions and updates regionFinder.statistics.
    Returns True only when trimRegion raised IndexError and the region was
    dropped; False otherwise.
    """
    lastReadPos = readStartPositions[-1]
    lastBasePosition = lastReadPos + regionFinder.readlen - 1
    newRegionIndex = regionFinder.statistics["index"] + 1
    if regionFinder.doDirectionality:
        region = Region.DirectionalRegion(readStartPositions[0], lastBasePosition, chrom=chromosome, index=newRegionIndex, label=regionFinder.regionLabel,
                                          numReads=totalWeight)
    else:
        region = Region.Region(readStartPositions[0], lastBasePosition, chrom=chromosome, index=newRegionIndex, label=regionFinder.regionLabel, numReads=totalWeight)

    if regionFinder.normalize:
        region.numReads /= regionFinder.sampleRDSsize

    allregions.append(int(region.numReads))
    regionLength = lastReadPos - region.start
    if not regionPassesCriteria(regionFinder, region.numReads, numStarts, regionLength):
        return False

    region.foldRatio = getFoldRatio(regionFinder, controlRDS, region.numReads, chromosome, region.start, lastReadPos, useMulti, doControl)
    if not (region.foldRatio >= regionFinder.minRatio):
        return False

    # first pass, with absolute numbers
    peak = findPeak(reads, region.start, regionLength, regionFinder.readlen, doWeight=True, leftPlus=regionFinder.doDirectionality, shift=regionFinder.shiftValue)
    if regionFinder.doTrim:
        try:
            lastReadPos = trimRegion(region, regionFinder, peak, lastReadPos, regionFinder.trimValue, reads, regionFinder.sampleRDSsize)
        except IndexError:
            return True

        region.foldRatio = getFoldRatio(regionFinder, controlRDS, region.numReads, chromosome, region.start, lastReadPos, useMulti, doControl)

    # just in case it changed, use latest data
    try:
        bestPos = peak.topPos[0]
        peakScore = peak.smoothArray[bestPos]
    except IndexError:
        # no usable peak position: skip this region
        return False

    if regionFinder.normalize:
        peakScore /= regionFinder.sampleRDSsize

    if regionFinder.listPeak:
        region.peakDescription = "%d\t%.1f" % (region.start + bestPos, peakScore)

    if useMulti:
        setMultireadPercentage(region, hitRDS, regionFinder.sampleRDSsize, totalWeight, uniqueReadCount, chromosome, lastReadPos,
                               regionFinder.normalize, regionFinder.doTrim)

    region.shift = peak.shift
    # check that we still pass threshold
    regionLength = lastReadPos - region.start
    plusRatio = float(peak.numPlus) / peak.numHits
    if regionAndPeakPass(regionFinder, region, regionLength, peakScore, plusRatio):
        try:
            updateRegion(region, regionFinder.doDirectionality, regionFinder.leftPlusRatio, peak.numLeftPlus, peak.numPlus, plusRatio)
            regionFinder.statistics["index"] += 1
            outregions.append(region)
            regionFinder.statistics["total"] += region.numReads
        except RegionDirectionError:
            regionFinder.statistics["failed"] += 1

    return False
553
554
def findBackgroundRegions(regionFinder, hitRDS, controlRDS, chromosome, useMulti):
    """Scan one chromosome of the control as a pseudo-sample (used with -revbackground).

    Regions are grouped as in findPeakRegions, but from the control reads,
    with the sample (hitRDS) acting as the background. Each closed region is
    evaluated by _scoreBackgroundRegion; regions that pass are tallied through
    regionFinder.updateControlStatistics rather than returned. Reads are
    always reset after a region is closed, even when it is rejected.

    Returns the int-truncated (RPM when normalizing) read count of every
    candidate control region.
    """
    #TODO: this is *almost* the same calculation - there are small yet important differences
    print("calculating background...")
    previousHit = - 1 * regionFinder.maxSpacing
    currentHitList = [-1]
    currentTotalWeight = 0
    currentReadList = []
    backregions = []
    numStarts = 0
    hitDict = controlRDS.getReadsDict(fullChrom=True, chrom=chromosome, withWeight=True, doMulti=useMulti, findallOptimize=True)
    maxCoord = controlRDS.getMaxCoordinate(chromosome, doMulti=useMulti)
    for read in hitDict[chromosome]:
        pos = read["start"]
        if previousRegionIsDone(pos, previousHit, regionFinder.maxSpacing, maxCoord):
            trimFailed = _scoreBackgroundRegion(regionFinder, hitRDS, chromosome, useMulti, currentHitList,
                                                currentReadList, currentTotalWeight, numStarts, backregions)
            if trimFailed:
                regionFinder.statistics["badRegionTrim"] += 1

            currentHitList = []
            currentTotalWeight = 0
            currentReadList = []
            numStarts = 0

        if pos not in currentHitList:
            numStarts += 1

        currentHitList.append(pos)
        weight = read["weight"]
        currentTotalWeight += weight
        currentReadList.append({"start": pos, "sense": read["sense"], "weight": weight})
        previousHit = pos

    return backregions


def _scoreBackgroundRegion(regionFinder, hitRDS, chromosome, useMulti, currentHitList, currentReadList,
                           currentTotalWeight, numStarts, backregions):
    """Threshold, trim and tally one closed candidate region of the control.

    Appends the region's read count to backregions and passes surviving
    regions to regionFinder.updateControlStatistics. Returns True only when
    trimRegion raised IndexError and the region was dropped; False otherwise.
    """
    lastReadPos = currentHitList[-1]
    lastBasePosition = lastReadPos + regionFinder.readlen - 1
    region = Region.Region(currentHitList[0], lastBasePosition, chrom=chromosome, label=regionFinder.regionLabel, numReads=currentTotalWeight)
    if regionFinder.normalize:
        region.numReads /= regionFinder.controlRDSsize

    backregions.append(int(region.numReads))
    # Keep using this (normalized) region for every threshold below, as
    # findPeakRegions does; the sample-side count numMock is normalized too,
    # so the fold ratio compares like with like.
    regionLength = lastReadPos - region.start
    if not regionPassesCriteria(regionFinder, region.numReads, numStarts, regionLength):
        return False

    numMock = 1. + hitRDS.getCounts(chromosome, region.start, lastReadPos, uniqs=True, multi=useMulti, splices=False, reportCombined=True)
    if regionFinder.normalize:
        numMock /= regionFinder.sampleRDSsize

    foldRatio = region.numReads / numMock
    if not (foldRatio >= regionFinder.minRatio):
        return False

    # first pass, with absolute numbers
    peak = findPeak(currentReadList, region.start, lastReadPos - region.start, regionFinder.readlen, doWeight=True,
                    leftPlus=regionFinder.doDirectionality, shift=regionFinder.shiftValue)

    if regionFinder.doTrim:
        try:
            # NOTE(review): 20. differs from regionFinder.trimValue used for the
            # sample (0.1 by default); confirm the intended scale in trimRegion.
            lastReadPos = trimRegion(region, regionFinder, peak, lastReadPos, 20., currentReadList, regionFinder.controlRDSsize)
        except IndexError:
            return True

        numMock = 1. + hitRDS.getCounts(chromosome, region.start, lastReadPos, uniqs=True, multi=useMulti, splices=False, reportCombined=True)
        if regionFinder.normalize:
            numMock /= regionFinder.sampleRDSsize

        foldRatio = region.numReads / numMock

    # just in case it changed, use latest data
    try:
        bestPos = peak.topPos[0]
        peakScore = peak.smoothArray[bestPos]
    except IndexError:
        return False

    # normalize to RPM
    if regionFinder.normalize:
        peakScore /= regionFinder.controlRDSsize

    # check that we still pass threshold
    regionLength = lastReadPos - region.start
    # NOTE(review): foldRatio is passed in regionPassesCriteria's numStarts
    # slot; presumably that slot is compared against minRatio, making this a
    # post-trim fold-ratio check -- confirm against regionPassesCriteria.
    if regionPassesCriteria(regionFinder, region.numReads, foldRatio, regionLength):
        regionFinder.updateControlStatistics(peak, region.numReads, peakScore)

    return False
637
638
def learnShift(regionFinder, hitRDS, chromosome, logfilename, outfilename,
               outfile, useMulti, doControl, controlRDS, combine5p):
    """ Learn regionFinder.shiftValue from candidate regions on one chromosome.

    Reads are walked in the order returned by hitRDS.getReadsDict(). A candidate region
    ends whenever previousRegionIsDone() reports that the next read is farther than
    regionFinder.maxSpacing from the previous one, or sits at the chromosome's maximum
    coordinate. A finished candidate becomes a training example when it passes
    regionPassesCriteria() at regionFinder.stringency and its getFoldRatio() is at least
    regionFinder.minRatio. The shift auto-detected for it by updateShiftDict() is then
    tallied in a histogram.

    getShiftValue() picks the final shift from that histogram. The choice is stored in
    regionFinder.shiftValue, printed, logged through writeLog() and written to outfile.
    """
    hitDict = hitRDS.getReadsDict(fullChrom=True, chrom=chromosome, flag=regionFinder.withFlag, withWeight=True, doMulti=useMulti, findallOptimize=True,
                                  strand=regionFinder.stranded, combine5p=combine5p)

    maxCoord = hitRDS.getMaxCoordinate(chromosome, doMulti=useMulti)
    print("learning shift.... will need at least 30 training sites")
    stringency = regionFinder.stringency
    previousHit = -1 * regionFinder.maxSpacing
    positionList = [-1]
    totalWeight = 0
    readList = []
    shiftDict = {}
    count = 0
    numStarts = 0
    for read in hitDict[chromosome]:
        pos = read["start"]
        if previousRegionIsDone(pos, previousHit, regionFinder.maxSpacing, maxCoord):
            if regionFinder.normalize:
                totalWeight /= regionFinder.sampleRDSsize

            regionStart = positionList[0]
            regionStop = positionList[-1]
            regionLength = regionStop - regionStart
            if regionPassesCriteria(regionFinder, totalWeight, numStarts, regionLength, stringency=stringency):
                foldRatio = getFoldRatio(regionFinder, controlRDS, totalWeight, chromosome, regionStart, regionStop, useMulti, doControl)
                if foldRatio >= regionFinder.minRatio:
                    updateShiftDict(shiftDict, readList, regionStart, regionLength, regionFinder.readlen)
                    count += 1

            # Start a new candidate region. Every per-region accumulator is reset here,
            # numStarts included; otherwise start counts from earlier regions would carry
            # over into the regionPassesCriteria() check of later ones.
            positionList = []
            totalWeight = 0
            readList = []
            numStarts = 0

        # numStarts counts distinct read start positions within the current region
        if pos not in positionList:
            numStarts += 1

        positionList.append(pos)
        weight = read["weight"]
        totalWeight += weight
        readList.append({"start": pos, "sense": read["sense"], "weight": weight})
        previousHit = pos

    outline = "#learn: stringency=%.2f min_signal=%.2f min_ratio=%.2f min_region_size=%d\n#number of training examples: %d" % (stringency,
                                                                                                                               stringency * regionFinder.minHits,
                                                                                                                               stringency * regionFinder.minRatio,
                                                                                                                               stringency * regionFinder.readlen,
                                                                                                                               count)

    print(outline)
    writeLog(logfilename, versionString, outfilename + outline)
    regionFinder.shiftValue = getShiftValue(shiftDict, count, logfilename, outfilename)
    outline = "#picked shiftValue to be %d" % regionFinder.shiftValue
    print(outline)
    outfile.write(outline + "\n")
    writeLog(logfilename, versionString, outfilename + outline)
696
697
def previousRegionIsDone(pos, previousHit, maxSpacing, maxCoord):
    """ Decide whether the read at pos closes the region being accumulated.

    True when pos is the chromosome's maximum coordinate, or when the gap between pos
    and the previous read position exceeds maxSpacing.
    """
    if pos == maxCoord:
        return True

    gap = abs(pos - previousHit)
    return gap > maxSpacing
700
701
def regionPassesCriteria(regionFinder, sumAll, numStarts, regionLength, stringency=1):
    """ Check a candidate region against the finder's thresholds, each scaled by stringency.

    - sumAll must be at least minHits.
    - numStarts must exceed minRatio. Callers pass either a count of distinct read
      starts (learnShift) or a fold ratio here.
    - regionLength must exceed readlen.
    """
    enoughSignal = sumAll >= stringency * regionFinder.minHits
    enoughRatio = numStarts > stringency * regionFinder.minRatio
    longEnough = regionLength > stringency * regionFinder.readlen
    return enoughSignal and enoughRatio and longEnough
704
705
def trimRegion(region, regionFinder, peak, regionStop, trimValue, currentReadList, totalReadCount):
    """ Trim low-signal edges from region and recompute its peak over the trimmed span.

    The threshold is trimValue times the smoothed score at peak.topPos[0]; that score is
    divided by totalReadCount when regionFinder.normalize is set. The left edge is the
    first offset from 0 that reaches the threshold, or the top position. The right edge is
    found the same way, scanning down from regionStop - region.start - 1.

    Side effects:
      - region.start, region.numReads and region.stop are updated.
      - peak.numPlus, peak.numLeftPlus, peak.topPos and peak.smoothArray are replaced with
        the values from a findPeak() run over the trimmed span.

    Returns the trimmed stop position (region.start before trimming plus the right-edge
    offset). region.stop is set to that value plus regionFinder.readlen - 1.
    Indexing peak.topPos / peak.smoothArray may raise IndexError.
    """
    topIndex = peak.topPos[0]
    peakScore = peak.smoothArray[topIndex]
    if regionFinder.normalize:
        peakScore /= totalReadCount

    minSignalThresh = trimValue * peakScore
    untrimmedStart = region.start
    leftOffset = findStartEdgePosition(peak, minSignalThresh)
    rightOffset = findStopEdgePosition(peak, regionStop - untrimmedStart - 1, minSignalThresh)

    trimmedStop = untrimmedStart + rightOffset
    region.start = untrimmedStart + leftOffset

    trimmedPeak = findPeak(currentReadList, region.start, trimmedStop - region.start, regionFinder.readlen, doWeight=True,
                           leftPlus=regionFinder.doDirectionality, shift=peak.shift)

    # copy the recomputed peak profile back onto the caller's peak object
    for attributeName in ("numPlus", "numLeftPlus", "topPos", "smoothArray"):
        setattr(peak, attributeName, getattr(trimmedPeak, attributeName))

    trimmedReads = trimmedPeak.numHits
    if regionFinder.normalize:
        trimmedReads /= totalReadCount

    region.numReads = trimmedReads
    region.stop = trimmedStop + regionFinder.readlen - 1

    return trimmedStop
735
736
def findStartEdgePosition(peak, minSignalThresh):
    """ Return the first offset, scanning up from 0, where peakEdgeLocated() is true.

    The scan always stops at peak.topPos[0] at the latest, because peakEdgeLocated()
    accepts the top position unconditionally.
    """
    offset = 0
    while True:
        if peakEdgeLocated(peak, offset, minSignalThresh):
            return offset
        offset += 1
743
744
def findStopEdgePosition(peak, stop, minSignalThresh):
    """ Return the first offset, scanning down from stop, where peakEdgeLocated() is true.

    NOTE(review): if stop starts below peak.topPos[0] and no position reaches the
    threshold, the scan goes negative. Negative indexes into smoothArray then count from
    its end (for a list or array), and the scan can end in IndexError.
    """
    offset = stop
    while True:
        if peakEdgeLocated(peak, offset, minSignalThresh):
            return offset
        offset -= 1
750
751
def peakEdgeLocated(peak, position, minSignalThresh):
    """ True when the smoothed signal at position reaches minSignalThresh, or when position
    is the peak's top position.
    """
    reachesThreshold = peak.smoothArray[position] >= minSignalThresh
    return reachesThreshold or position == peak.topPos[0]
754
755
def getFoldRatio(regionFinder, controlRDS, sumAll, chromosome, regionStart, regionStop, useMulti, doControl):
    """ Return the enrichment of sumAll over the control for regionStart..regionStop.

    Without a control (doControl false), regionFinder.minRatio is returned unchanged.
    Otherwise the denominator is 1 plus the unique (and, with useMulti, multi) control
    counts in the span. It is divided by regionFinder.controlRDSsize when
    regionFinder.normalize is set.
    """
    #TODO: this needs to be generalized as there is a point at which we want to use the sampleRDS instead of controlRDS
    if not doControl:
        return regionFinder.minRatio

    controlCount = controlRDS.getCounts(chromosome, regionStart, regionStop, uniqs=True, multi=useMulti, splices=False, reportCombined=True)
    numMock = 1. + controlCount
    if regionFinder.normalize:
        numMock /= regionFinder.controlRDSsize

    return sumAll / numMock
770
771
def updateShiftDict(shiftDict, readList, regionStart, regionLength, readlen):
    """ Auto-detect the shift for one training region and add one vote for it to shiftDict. """
    learnedShift = findPeak(readList, regionStart, regionLength, readlen, doWeight=True, shift="auto").shift
    shiftDict[learnedShift] = shiftDict.get(learnedShift, 0) + 1
778
779
def getShiftValue(shiftDict, count, logfilename, outfilename):
    """ Choose the learned shift from the shiftDict histogram.

    With at least 30 training examples, the most frequent shift (getBestShiftInDict) is
    returned and the histogram is printed. With fewer, a warning is printed and logged and
    0 is returned.
    """
    if count >= 30:
        bestShift = getBestShiftInDict(shiftDict)
        print(shiftDict)
        return bestShift

    outline = "#too few training examples to pick a shiftValue - defaulting to 0\n#consider picking a lower minimum or threshold"
    print(outline)
    writeLog(logfilename, versionString, outfilename + outline)
    return 0
791
792
def getRegion(regionStart, regionStop, factor, index, chromosome, sumAll, foldRatio, multiP,
              peakDescription, shift, doDirectionality, leftPlusRatio, numLeft,
              numPlus, plusRatio):
    """ Build the output region object for a region that passed the signal criteria.

    Without doDirectionality a plain Region.Region is returned. With it, the region must
    also have numLeft / numPlus above leftPlusRatio, otherwise RegionDirectionError is
    raised. On success a Region.DirectionalRegion is returned, carrying plusRatio and
    numLeft / numPlus as percentages.
    NOTE(review): numPlus == 0 raises ZeroDivisionError on the directional path.
    """
    if not doDirectionality:
        # we have a region, but didn't check for directionality
        return Region.Region(regionStart, regionStop, factor, index, chromosome,
                             sumAll, foldRatio, multiP, peakDescription, shift)

    if not leftPlusRatio < numLeft / numPlus:
        raise RegionDirectionError

    # we have a region that passes all criteria
    return Region.DirectionalRegion(regionStart, regionStop,
                                    factor, index, chromosome, sumAll,
                                    foldRatio, multiP, plusRatio * 100., 100. * numLeft / numPlus,
                                    peakDescription, shift)
815
816
def setMultireadPercentage(region, hitRDS, hitRDSsize, currentTotalWeight, currentUniqueCount, chromosome, lastReadPos, normalize, doTrim):
    """ Store in region.multiP the percentage of region.numReads that comes from multireads.

    When doTrim is set, the multiread amount is queried from hitRDS over
    region.start..lastReadPos. Otherwise it is currentTotalWeight - currentUniqueCount.
    It is divided by hitRDSsize when normalize is set. If region.numReads is zero,
    region.multiP is left untouched.
    """
    if doTrim:
        multiSignal = hitRDS.getMultiCount(chromosome, region.start, lastReadPos)
    else:
        multiSignal = currentTotalWeight - currentUniqueCount

    # normalize to RPM
    if normalize:
        multiSignal /= hitRDSsize

    try:
        percentage = 100. * (multiSignal / region.numReads)
    except ZeroDivisionError:
        return
    else:
        region.multiP = percentage
833
834
def regionAndPeakPass(regionFinder, region, regionLength, peakScore, plusRatio):
    """ Return True only if region passes regionPassesCriteria() (using its numReads and
    foldRatio), its peak reaches regionFinder.minPeak, and plusRatio lies within
    [minPlusRatio, maxPlusRatio].
    """
    if not regionPassesCriteria(regionFinder, region.numReads, region.foldRatio, regionLength):
        return False

    peakHighEnough = peakScore >= regionFinder.minPeak
    return bool(peakHighEnough and regionFinder.minPlusRatio <= plusRatio <= regionFinder.maxPlusRatio)
842
843
def updateRegion(region, doDirectionality, leftPlusRatio, numLeft, numPlus, plusRatio):
    """ Refresh region.plusP / region.leftP (percentages) when directionality is checked.

    Does nothing without doDirectionality. Raises RegionDirectionError when
    numLeft / numPlus does not exceed leftPlusRatio.
    """
    if not doDirectionality:
        return

    if not leftPlusRatio < numLeft / numPlus:
        raise RegionDirectionError

    region.plusP = plusRatio * 100.
    region.leftP = 100. * numLeft / numPlus
852
853
def writeNoRevBackgroundResults(regionFinder, outregions, outfile, doPvalue, shiftDict,
                                allregions, header):
    """ Write chromosome results without a reverse background.

    Any p-values are computed against the sample's own regions (allregions), and no
    background regions are supplied.
    """
    noBackgroundRegions = []
    writeChromosomeResults(regionFinder, outregions, outfile, doPvalue, shiftDict, allregions, header,
                           pValueType="self", backregions=noBackgroundRegions)
859
860
def writeChromosomeResults(regionFinder, outregions, outfile, doPvalue, shiftDict,
                           allregions, header, backregions=None, pValueType="none"):
    """ Print the background statistics and header, then write outregions via writeRegions().

    When doPvalue is set, the Poisson mean comes from allregions if pValueType is "self",
    otherwise from backregions. Both lists are sorted in place by calculatePoissonMean().
    When doPvalue is unset, no mean is computed and writeRegions() does not use it.
    """
    # fresh list per call: calculatePoissonMean() sorts its argument in place, so a
    # shared mutable default would be modified across calls
    if backregions is None:
        backregions = []

    print("%s %s" % (regionFinder.statistics["mIndex"], regionFinder.statistics["mTotal"]))

    # always bind poissonmean: writeRegions() takes it positionally even when doPvalue is
    # false (it is only read when doPvalue is set)
    poissonmean = 0
    if doPvalue:
        if pValueType == "self":
            poissonmean = calculatePoissonMean(allregions)
        else:
            poissonmean = calculatePoissonMean(backregions)

    print(header)
    writeRegions(outregions, outfile, doPvalue, poissonmean, shiftValue=regionFinder.shiftValue, reportshift=regionFinder.reportshift, shiftDict=shiftDict)
873
874
def calculatePoissonMean(dataList):
    """ Return the arithmetic mean of dataList, for use as a Poisson mean.

    dataList is sorted in place as a side effect. An empty list gives 0. The sample size
    and the mean are printed.
    """
    dataList.sort()
    numValues = float(len(dataList))
    if numValues:
        poissonmean = sum(dataList) / numValues
    else:
        poissonmean = 0

    print("Poisson n=%d, p=%f" % (numValues, poissonmean))

    return poissonmean
886
887
def writeRegions(outregions, outfile, doPvalue, poissonmean, shiftValue=0, reportshift=False, shiftDict=None):
    """ Print each region in outregions and write it as one line to outfile.

    When shiftValue is "auto" and reportshift is set, each region's shift is tallied in
    shiftDict. A caller that wants the tally must pass its own dict; when none is given, a
    fresh one is used per call.
    When doPvalue is set, the value of calculatePValue(int(region.numReads), poissonmean)
    is appended to the line as a tab-separated "%1.2g" field.
    """
    # a fresh dict per call: the previous `shiftDict={}` default was mutated below and
    # therefore accumulated tallies across unrelated calls
    if shiftDict is None:
        shiftDict = {}

    tallyShifts = shiftValue == "auto" and reportshift
    for region in outregions:
        if tallyShifts:
            shiftDict[region.shift] = shiftDict.get(region.shift, 0) + 1

        outline = getRegionString(region, reportshift)

        if doPvalue:
            sumAll = int(region.numReads)
            pValue = calculatePValue(sumAll, poissonmean)
            outline += "\t%1.2g" % pValue

        print(outline)
        outfile.write(outline + "\n")
906
907
def calculatePValue(sum, poissonmean):
    """ Return the Poisson probability mass P(X == sum) for X ~ Poisson(poissonmean).

    This is exp(-mean) * mean**sum / sum!, the point probability rather than a tail
    probability. Based on the iterative Poisson recipe from
    http://stackoverflow.com/questions/280797, but the product is accumulated in log space.
    The old `math.exp(-poissonmean)` start underflowed to 0.0 for means above about 745,
    which zeroed every reported value.

    The parameter name `sum` shadows the builtin; it is kept for interface compatibility.
    A zero (or non-positive) mean puts all mass at 0: 1.0 is returned for sum <= 0, and
    0.0 otherwise.
    """
    count = sum
    if poissonmean <= 0:
        if count <= 0:
            return 1.0
        return 0.0

    logMean = math.log(poissonmean)
    logProbability = -poissonmean
    # multiply in mean / i for i = 1..count, as additions of logarithms
    for i in range(1, count + 1):
        logProbability += logMean - math.log(i)

    return math.exp(logProbability)
915
916
def getRegionString(region, reportShift):
    """ Format region for output, including its shift when reportShift is set. """
    formatter = region.printRegionWithShift if reportShift else region.printRegion
    return formatter()
924
925
def getFooter(regionFinder, shiftDict, doRevBackground):
    """ Build the "#"-prefixed summary footer from regionFinder.statistics.

    The footer always includes the total and region count. Depending on the settings it
    also includes:
      - the directionality failure count, when regionFinder.doDirectionality is set;
      - the background count and FDR percentage (capped at 100), when doRevBackground is set;
      - the mode of the tallied shifts, when shiftValue is "auto", reportshift is set and
        any shifts were tallied;
      - the number of regions discarded by trimming, when it is non-zero.
    Lines are joined with newlines.
    """
    statistics = regionFinder.statistics
    index = statistics["index"]
    mIndex = statistics["mIndex"]
    footerLines = ["#stats:\t%.1f RPM in %d regions" % (statistics["total"], index)]
    if regionFinder.doDirectionality:
        footerLines.append("#\t\t%d additional regions failed directionality filter" % statistics["failed"])

    if doRevBackground:
        try:
            percent = min(100. * (float(mIndex)/index), 100.)
        except ZeroDivisionError:
            percent = 0.

        footerLines.append("#%d regions (%.1f RPM) found in background (FDR = %.2f percent)" % (mIndex, statistics["mTotal"], percent))

    # shiftDict stays empty when no regions were written; getBestShiftInDict() takes max()
    # over its items and would raise ValueError on an empty dict
    if regionFinder.shiftValue == "auto" and regionFinder.reportshift and shiftDict:
        bestShift = getBestShiftInDict(shiftDict)
        footerLines.append("#mode of shift values: %d" % bestShift)

    if statistics["badRegionTrim"] > 0:
        footerLines.append("#%d regions discarded due to trimming problems" % statistics["badRegionTrim"])

    return "\n".join(footerLines)
949
950
def getBestShiftInDict(shiftDict):
    """ Return the shift with the highest tally in shiftDict.

    On ties, the first such key in dict iteration order is returned. Raises ValueError
    for an empty dict.
    """
    return max(shiftDict, key=shiftDict.get)
953
954
if __name__ == "__main__":
    # Script entry point: pass the full argument vector (program name included) to main().
    main(sys.argv)