development release: conversion of ReadDataset to use BAM files
[erange.git] / findall.py
1 """
2     usage: python $ERANGEPATH/findall.py label samplerdsfile regionoutfile
3            [--control controlrdsfile] [--minimum minHits] [--ratio minRatio]
4            [--spacing maxSpacing] [--listPeak] [--shift #bp | learn] [--learnFold num]
5            [--noshift] [--autoshift] [--reportshift] [--nomulti] [--minPlus fraction]
6            [--maxPlus fraction] [--leftPlus fraction] [--minPeak RPM] [--raw]
7            [--revbackground] [--pvalue self|back|none] [--nodirectionality]
8            [--strandfilter plus/minus] [--trimvalue percent] [--notrim]
9            [--cache pages] [--log altlogfile] [--flag aflag] [--append] [--RNA]
10
11            where values in brackets are optional and label is an arbitrary string.
12
           Use --ratio (default 4 fold) to set the minimum fold enrichment
           over the control, --minimum (default 4) to set the minimum number of
           reads (RPM) within the region, and --spacing (default readlen) to set
           the maximum distance between reads in the region. --listPeak lists the
           peak of the region. Peaks must be higher than --minPeak (default 0.5 RPM).
           Pvalues are calculated from the sample (change with --pvalue),
           unless the --revbackground flag and a control RDS file are provided.
20
           By default, all numbers and parameters are on a reads per
           million (RPM) basis. --raw will treat all settings, ratios and reported
           numbers as raw counts rather than RPM. Use --notrim to turn off region
           trimming and --trimvalue to control trimming (default 10% of peak signal).
25
           The peak finder uses minimal directionality information that can
           be turned off with --nodirectionality ; the fraction of + strand reads
           required to be to the left of the peak (default 0.3) can be set with
           --leftPlus ; --minPlus and --maxPlus change the minimum/maximum fraction
           of plus reads in a region (defaults 0.25 and 0.75, respectively).
31
           Use --shift to shift reads either by half the expected
           fragment length (default 0 bp) or '--shift learn' to learn the shift
           based on the first chromosome. If you prefer to learn the shift
           manually, use --autoshift to calculate a per-region shift value, which
           can be reported using --reportshift. --strandfilter should only be used
           when explicitly calling unshifted stranded peaks from non-ChIP-seq
           data such as directional RNA-seq. regionoutfile is overwritten by
           default unless the --append flag is given.
40 """
41
# psyco is an optional accelerator: use it when it is installed, otherwise
# carry on without it.  Only a missing module is tolerated; any other failure
# (including KeyboardInterrupt) is allowed to surface instead of being hidden
# by a bare except.
try:
    import psyco
    psyco.full()
except ImportError:
    pass
47
48 import sys
49 import math
50 import string
51 import optparse
52 import operator
53 from commoncode import writeLog, findPeak, getBestShiftForRegion, getConfigParser, getConfigOption, getConfigIntOption, getConfigFloatOption, getConfigBoolOption
54 import ReadDataset
55 import Region
56
57
# Version banner: echoed as soon as the module is loaded, and reused by the
# functions below when they write the log entry and the output file header.
versionString = "findall: version 3.2.1"
print(versionString)
60
class RegionDirectionError(Exception):
    """Signals that a candidate region was rejected while being finalized.

    findPeakRegions catches this around updateRegion() and counts the region
    under statistics["failed"]; presumably it reports a failed strand
    directionality check (the raising code is not in this part of the file).
    """
63             
64
class RegionFinder(object):
    """Thresholds, options and running counters for one findall run.

    An instance is configured once and then shared by the region-finding
    functions, which read its thresholds and update ``self.statistics``.
    ``sampleRDSsize`` and ``controlRDSsize`` start at 1 and, like ``readlen``
    (not set here at all), are assigned by findall() once the datasets are open.
    """

    def __init__(self, label, minRatio=4.0, minPeak=0.5, minPlusRatio=0.25, maxPlusRatio=0.75, leftPlusRatio=0.3, strandfilter="",
                 minHits=4.0, trimValue=0.1, doTrim=True, doDirectionality=True, shiftValue=0, maxSpacing=50, withFlag="",
                 normalize=True, listPeak=False, reportshift=False, stringency=1.0):
        """Store the calling parameters.

        strandfilter "plus"/"minus" restricts the analysis to one strand and
        overrides minPlusRatio/maxPlusRatio with (0.9, 1.0) / (0.0, 0.1).
        minPeak is lowered to minRatio when minRatio is the smaller of the two,
        and stringency is never allowed below 1.0.
        """
        # Running counters updated by the region-finding functions.
        self.statistics = {"index": 0,
                           "total": 0,
                           "mIndex": 0,
                           "mTotal": 0,
                           "failed": 0,
                           "badRegionTrim": 0
        }

        self.regionLabel = label
        self.rnaSettings = False
        self.controlRDSsize = 1
        self.sampleRDSsize = 1
        self.minRatio = minRatio
        self.minPeak = minPeak
        self.leftPlusRatio = leftPlusRatio

        # A strand filter forces the plus-read fraction window for that strand.
        self.stranded = "both"
        if strandfilter == "plus":
            self.stranded = "+"
            minPlusRatio = 0.9
            maxPlusRatio = 1.0
        elif strandfilter == "minus":
            self.stranded = "-"
            minPlusRatio = 0.0
            maxPlusRatio = 0.1

        if minRatio < minPeak:
            self.minPeak = minRatio

        self.minPlusRatio = minPlusRatio
        self.maxPlusRatio = maxPlusRatio
        self.strandfilter = strandfilter
        self.minHits = minHits
        self.trimValue = trimValue
        self.doTrim = doTrim
        self.doDirectionality = doDirectionality

        # Human-readable trim setting used in the settings summaries.
        if self.doTrim:
            self.trimString = "%2.1f%%" % (100. * self.trimValue)
        else:
            self.trimString = "none"

        self.shiftValue = shiftValue
        self.maxSpacing = maxSpacing
        self.withFlag = withFlag
        self.normalize = normalize
        self.listPeak = listPeak
        self.reportshift = reportshift
        self.stringency = max(stringency, 1.0)


    def useRNASettings(self, readlen):
        """Switch to the RNA preset: no shift, no trimming, no directionality,
        and a maximum read spacing equal to readlen."""
        self.rnaSettings = True
        self.shiftValue = 0
        self.doTrim = False
        self.doDirectionality = False
        self.maxSpacing = readlen


    def getHeader(self, doPvalue):
        """Return the tab-separated column header line for the region file.

        Optional columns follow the current settings (directionality, peak
        listing, shift reporting); a pValue column is added when doPvalue is true.
        """
        if self.normalize:
            countType = "RPM"
        else:
            countType = "COUNT"

        headerFields = ["#regionID\tchrom\tstart\tstop", countType, "fold\tmulti%"]
        optionalColumns = [(self.doDirectionality, "plus%\tleftPlus%"),
                           (self.listPeak, "peakPos\tpeakHeight"),
                           (self.reportshift, "readShift"),
                           (doPvalue, "pValue")]

        for enabled, columnText in optionalColumns:
            if enabled:
                headerFields.append(columnText)

        return "\t".join(headerFields)


    def printSettings(self, doRevBackground, ptype, doControl, useMulti, doCache, pValueType):
        """Print a blank line, the status messages and the options summary."""
        print("")
        self.printStatusMessages(doRevBackground, ptype, doControl, useMulti)
        self.printOptionsSummary(useMulti, doCache, pValueType)


    def printStatusMessages(self, doRevBackground, ptype, doControl, useMulti):
        """Print one line for each non-default mode that is in effect."""
        if self.shiftValue == "learn":
            print("Will try to learn shift")

        if self.normalize:
            print("Normalizing to RPM")

        if doRevBackground:
            print("Swapping IP and background to calculate FDR")

        if ptype != "":
            if ptype in ["NONE", "SELF"]:
                pass
            elif ptype == "BACK":
                # 'back' is only usable with a control dataset and -revbackground
                if not (doControl and doRevBackground):
                    print("must have a control dataset and -revbackground for pValue type 'back'")
            else:
                print("could not use pValue type : %s" % ptype)

        if self.withFlag != "":
            print("restrict to flag = %s" % self.withFlag)

        if not useMulti:
            print("using unique reads only")

        if self.rnaSettings:
            print("using settings appropriate for RNA: -nodirectionality -notrim -noshift")

        if self.strandfilter == "plus":
            print("only analyzing reads on the plus strand")
        elif self.strandfilter == "minus":
            print("only analyzing reads on the minus strand")


    def printOptionsSummary(self, useMulti, doCache, pValueType):
        """Print the effective thresholds in three summary lines."""
        print("\nenforceDirectionality=%s listPeak=%s nomulti=%s cache=%s " % (self.doDirectionality, self.listPeak, not useMulti, doCache))
        print("spacing<%d minimum>%.1f ratio>%.1f minPeak=%.1f\ttrimmed=%s\tstrand=%s" % (self.maxSpacing, self.minHits, self.minRatio, self.minPeak, self.trimString, self.stranded))
        ratioValues = (self.minPlusRatio, self.maxPlusRatio, self.leftPlusRatio, self.shiftValue, pValueType)
        try:
            print("minPlus=%.2f maxPlus=%.2f leftPlus=%.2f shift=%d pvalue=%s" % ratioValues)
        except TypeError:
            # shiftValue may be a mode name such as "learn" or "auto" rather than a number
            print("minPlus=%.2f maxPlus=%.2f leftPlus=%.2f shift=%s pvalue=%s" % ratioValues)


    def getAnalysisDescription(self, hitfile, useMulti, doCache, pValueType, controlfile, doControl):
        """Return the multi-line '#'-prefixed description written at the top of
        the region file: version, datasets with their sizes, and settings."""
        description = ["#ERANGE %s" % versionString]
        if doControl:
            description.append("#enriched sample:\t%s (%.1f M reads)\n#control sample:\t%s (%.1f M reads)" % (hitfile, self.sampleRDSsize, controlfile, self.controlRDSsize))
        else:
            description.append("#enriched sample:\t%s (%.1f M reads)\n#control sample: none" % (hitfile, self.sampleRDSsize))

        if self.withFlag != "":
            description.append("#restrict to Flag = %s" % self.withFlag)

        description.append("#enforceDirectionality=%s listPeak=%s nomulti=%s cache=%s" % (self.doDirectionality, self.listPeak, not useMulti, doCache))
        description.append("#spacing<%d minimum>%.1f ratio>%.1f minPeak=%.1f trimmed=%s strand=%s" % (self.maxSpacing, self.minHits, self.minRatio, self.minPeak, self.trimString, self.stranded))
        ratioValues = (self.minPlusRatio, self.maxPlusRatio, self.leftPlusRatio, self.shiftValue, pValueType)
        try:
            description.append("#minPlus=%.2f maxPlus=%.2f leftPlus=%.2f shift=%d pvalue=%s" % ratioValues)
        except TypeError:
            # shiftValue may be a mode name such as "learn" or "auto" rather than a number
            description.append("#minPlus=%.2f maxPlus=%.2f leftPlus=%.2f shift=%s pvalue=%s" % ratioValues)

        return "\n".join(description)


    def updateControlStatistics(self, peak, sumAll, peakScore):
        """Update the control ("m"-prefixed) counters for one background region.

        A region whose peak score and plus-strand fraction are within limits
        adds 1 to statistics["mIndex"] and sumAll to statistics["mTotal"].
        When directionality is enforced, the fraction of plus reads left of the
        peak must also exceed leftPlusRatio; otherwise statistics["failed"] is
        incremented.

        peak must provide numPlus, numHits and numLeftPlus.
        """
        if not peak.numHits:
            # no hits: the plus fraction is undefined, so there is nothing to count
            return

        plusRatio = float(peak.numPlus) / peak.numHits
        if not (peakScore >= self.minPeak and self.minPlusRatio <= plusRatio <= self.maxPlusRatio):
            return

        if self.doDirectionality:
            if not peak.numPlus:
                # No plus-strand reads (possible when minPlusRatio is 0.0): the
                # left-plus fraction cannot be computed, so the check fails
                # instead of raising ZeroDivisionError.
                self.statistics["failed"] += 1
                return

            # float() forces true division; integer counts would otherwise
            # truncate the fraction to 0 or 1 under Python 2.
            leftPlusFraction = float(peak.numLeftPlus) / peak.numPlus
            if self.leftPlusRatio >= leftPlusFraction:
                self.statistics["failed"] += 1
                return

        # the region counts (directionality either passed or was not checked)
        self.statistics["mIndex"] += 1
        self.statistics["mTotal"] += sumAll
238
239
def usage():
    """Print the module docstring, which doubles as the command-line help."""
    print(__doc__)
242
243
def main(argv=None):
    """Command-line entry point: parse argv, build a RegionFinder, run findall().

    argv defaults to sys.argv; argv[0] is skipped.  At least three positional
    arguments are required (label, sample RDS file, output file); otherwise the
    usage text is printed and the process exits with status 2.
    """
    if not argv:
        argv = sys.argv

    parser = makeParser()
    (options, args) = parser.parse_args(argv[1:])

    if len(args) < 3:
        usage()
        sys.exit(2)

    factor = args[0]
    hitfile = args[1]
    outfilename = args[2]

    # Shift precedence, lowest to highest: default 0, --autoshift, --shift, --noshift.
    shiftValue = 0

    if options.autoshift:
        shiftValue = "auto"

    if options.shift is not None:
        if options.shift == "learn":
            shiftValue = "learn"
        else:
            try:
                shiftValue = int(options.shift)
            except ValueError:
                # previously dropped without a word; keep going but say so
                print("ignoring unrecognized --shift value: %s" % options.shift)

    # The flag has been stored under both spellings of the attribute name
    # ("noshift" for the configured default, "noShift" for the command-line
    # switch); honour whichever one is set so --noshift is not silently ignored.
    if options.noshift or getattr(options, "noShift", False):
        shiftValue = 0

    if options.doAppend:
        outputMode = "a"
    else:
        outputMode = "w"

    regionFinder = RegionFinder(factor, minRatio=options.minRatio, minPeak=options.minPeak, minPlusRatio=options.minPlusRatio,
                                maxPlusRatio=options.maxPlusRatio, leftPlusRatio=options.leftPlusRatio, strandfilter=options.strandfilter,
                                minHits=options.minHits, trimValue=options.trimValue, doTrim=options.doTrim,
                                doDirectionality=options.doDirectionality, shiftValue=shiftValue, maxSpacing=options.maxSpacing,
                                withFlag=options.withFlag, normalize=options.normalize, listPeak=options.listPeak,
                                reportshift=options.reportshift, stringency=options.stringency)

    findall(regionFinder, hitfile, outfilename, options.logfilename, outputMode, options.rnaSettings,
            options.cachePages, options.ptype, options.controlfile, options.doRevBackground,
            options.useMulti, options.combine5p)
289
290
def makeParser():
    """Build the optparse parser for findall.

    Every option's default is read from the "findall" section of the ERANGE
    configuration (via the getConfig*Option helpers), falling back to the
    literal given here, and installed with set_defaults().  The module
    docstring is used as the usage text.
    """
    usageText = __doc__

    parser = optparse.OptionParser(usage=usageText)
    parser.add_option("--control", dest="controlfile")
    parser.add_option("--minimum", type="float", dest="minHits")
    parser.add_option("--ratio", type="float", dest="minRatio")
    parser.add_option("--spacing", type="int", dest="maxSpacing")
    parser.add_option("--listPeak", action="store_true", dest="listPeak")
    parser.add_option("--shift", dest="shift")
    parser.add_option("--learnFold", type="float", dest="stringency")
    # dest must be "noshift": that is the name given a default below and the
    # attribute main() reads.  With dest="noShift" the switch was stored where
    # nothing looked at it, so --noshift had no effect.
    parser.add_option("--noshift", action="store_true", dest="noshift")
    parser.add_option("--autoshift", action="store_true", dest="autoshift")
    parser.add_option("--reportshift", action="store_true", dest="reportshift")
    parser.add_option("--nomulti", action="store_false", dest="useMulti")
    parser.add_option("--minPlus", type="float", dest="minPlusRatio")
    parser.add_option("--maxPlus", type="float", dest="maxPlusRatio")
    parser.add_option("--leftPlus", type="float", dest="leftPlusRatio")
    parser.add_option("--minPeak", type="float", dest="minPeak")
    parser.add_option("--raw", action="store_false", dest="normalize")
    parser.add_option("--revbackground", action="store_true", dest="doRevBackground")
    parser.add_option("--pvalue", dest="ptype")
    parser.add_option("--nodirectionality", action="store_false", dest="doDirectionality")
    parser.add_option("--strandfilter", dest="strandfilter")
    parser.add_option("--trimvalue", type="float", dest="trimValue")
    parser.add_option("--notrim", action="store_false", dest="doTrim")
    parser.add_option("--cache", type="int", dest="cachePages")
    parser.add_option("--log", dest="logfilename")
    parser.add_option("--flag", dest="withFlag")
    parser.add_option("--append", action="store_true", dest="doAppend")
    parser.add_option("--RNA", action="store_true", dest="rnaSettings")
    parser.add_option("--combine5p", action="store_true", dest="combine5p")

    # Defaults: configuration file value if present, else the literal fallback.
    configParser = getConfigParser()
    section = "findall"
    minHits = getConfigFloatOption(configParser, section, "minHits", 4.0)
    minRatio = getConfigFloatOption(configParser, section, "minRatio", 4.0)
    maxSpacing = getConfigIntOption(configParser, section, "maxSpacing", 50)
    listPeak = getConfigBoolOption(configParser, section, "listPeak", False)
    shift = getConfigOption(configParser, section, "shift", None)
    stringency = getConfigFloatOption(configParser, section, "stringency", 4.0)
    noshift = getConfigBoolOption(configParser, section, "noshift", False)
    autoshift = getConfigBoolOption(configParser, section, "autoshift", False)
    reportshift = getConfigBoolOption(configParser, section, "reportshift", False)
    minPlusRatio = getConfigFloatOption(configParser, section, "minPlusRatio", 0.25)
    maxPlusRatio = getConfigFloatOption(configParser, section, "maxPlusRatio", 0.75)
    leftPlusRatio = getConfigFloatOption(configParser, section, "leftPlusRatio", 0.3)
    minPeak = getConfigFloatOption(configParser, section, "minPeak", 0.5)
    normalize = getConfigBoolOption(configParser, section, "normalize", True)
    logfilename = getConfigOption(configParser, section, "logfilename", "findall.log")
    withFlag = getConfigOption(configParser, section, "withFlag", "")
    doDirectionality = getConfigBoolOption(configParser, section, "doDirectionality", True)
    trimValue = getConfigFloatOption(configParser, section, "trimValue", 0.1)
    doTrim = getConfigBoolOption(configParser, section, "doTrim", True)
    doAppend = getConfigBoolOption(configParser, section, "doAppend", False)
    rnaSettings = getConfigBoolOption(configParser, section, "rnaSettings", False)
    cachePages = getConfigOption(configParser, section, "cachePages", None)
    ptype = getConfigOption(configParser, section, "ptype", "")
    controlfile = getConfigOption(configParser, section, "controlfile", None)
    doRevBackground = getConfigBoolOption(configParser, section, "doRevBackground", False)
    useMulti = getConfigBoolOption(configParser, section, "useMulti", True)
    strandfilter = getConfigOption(configParser, section, "strandfilter", "")
    combine5p = getConfigBoolOption(configParser, section, "combine5p", False)

    parser.set_defaults(minHits=minHits, minRatio=minRatio, maxSpacing=maxSpacing, listPeak=listPeak, shift=shift,
                        stringency=stringency, noshift=noshift, autoshift=autoshift, reportshift=reportshift,
                        minPlusRatio=minPlusRatio, maxPlusRatio=maxPlusRatio, leftPlusRatio=leftPlusRatio, minPeak=minPeak,
                        normalize=normalize, logfilename=logfilename, withFlag=withFlag, doDirectionality=doDirectionality,
                        trimValue=trimValue, doTrim=doTrim, doAppend=doAppend, rnaSettings=rnaSettings,
                        cachePages=cachePages, ptype=ptype, controlfile=controlfile, doRevBackground=doRevBackground, useMulti=useMulti,
                        strandfilter=strandfilter, combine5p=combine5p)

    return parser
364
365
def findall(regionFinder, hitfile, outfilename, logfilename="findall.log", outputMode="w", rnaSettings=False, cachePages=None,
            ptype="", controlfile=None, doRevBackground=False, useMulti=True, combine5p=False):
    """Run the peak finder over every chromosome and write the region file.

    Opens the sample (and optional control) dataset, records their sizes in
    millions of reads and the read length on regionFinder, writes the header
    to outfilename (opened with outputMode), processes each chromosome, then
    writes the footer and logs the run.

    ptype is the requested p-value mode ("none", "self" or "back"); it is
    matched case-insensitively.  cachePages, when not None, enables dataset
    caching.
    """
    writeLog(logfilename, versionString, " ".join(sys.argv[1:]))

    # The checks further down compare against upper-case names while the usage
    # text documents lower-case values, so normalise once here.
    ptype = (ptype or "").upper()

    doCache = cachePages is not None
    controlRDS = None
    doControl = controlfile is not None
    if doControl:
        print("\ncontrol:")
        controlRDS = openRDSFile(controlfile, cachePages=cachePages, doCache=doCache)
        regionFinder.controlRDSsize = len(controlRDS) / 1000000.

    print("\nsample:")
    hitRDS = openRDSFile(hitfile, cachePages=cachePages, doCache=doCache)
    regionFinder.sampleRDSsize = len(hitRDS) / 1000000.
    pValueType = getPValueType(ptype, doControl, doRevBackground)
    doPvalue = pValueType != "none"
    regionFinder.readlen = hitRDS.getReadSize()
    if rnaSettings:
        regionFinder.useRNASettings(regionFinder.readlen)

    regionFinder.printSettings(doRevBackground, ptype, doControl, useMulti, doCache, pValueType)
    outfile = open(outfilename, outputMode)
    try:
        # try/finally so the output file is closed even if a chromosome fails
        header = writeOutputFileHeader(regionFinder, outfile, hitfile, useMulti, doCache, pValueType, doPvalue, controlfile, doControl)
        shiftDict = {}
        chromosomeList = getChromosomeListToProcess(hitRDS, controlRDS, doControl)
        for chromosome in chromosomeList:
            #TODO: QForAli -Really? Use first chr shift value for all of them
            if regionFinder.shiftValue == "learn":
                learnShift(regionFinder, hitRDS, chromosome, logfilename, outfilename, outfile, useMulti, doControl, controlRDS, combine5p)

            allregions, outregions = findPeakRegions(regionFinder, hitRDS, chromosome, logfilename, outfilename, outfile, useMulti, doControl, controlRDS, combine5p)
            if doRevBackground:
                backregions = findBackgroundRegions(regionFinder, hitRDS, controlRDS, chromosome, useMulti)
                writeChromosomeResults(regionFinder, outregions, outfile, doPvalue, shiftDict, allregions, header, backregions=backregions, pValueType=pValueType)
            else:
                writeNoRevBackgroundResults(regionFinder, outregions, outfile, doPvalue, shiftDict, allregions, header)

        footer = getFooter(regionFinder, shiftDict, doRevBackground)
        print(footer)
        outfile.write("%s\n" % footer)
    finally:
        outfile.close()

    # the footer is flattened onto one line for the log entry
    writeLog(logfilename, versionString, outfilename + footer.replace("\n#"," | ")[:-1])
407     outfile.close()
408     writeLog(logfilename, versionString, outfilename + footer.replace("\n#"," | ")[:-1])
409
410
def getPValueType(ptype, doControl, doRevBackground):
    """Resolve the requested p-value mode to "none", "self" or "back".

    ptype is matched case-insensitively, so the lower-case values shown in the
    usage text (self|back|none) are recognised as well as the upper-case ones.

      * "none" -> "none"
      * "self" -> "self"
      * "back" -> "back" only when both doControl and doRevBackground are
        true, otherwise "self"
      * anything else (including "" or None) -> "back" when doRevBackground
        is true, otherwise "self"
    """
    requested = (ptype or "").upper()

    if requested == "NONE":
        return "none"

    if requested == "SELF":
        return "self"

    if requested == "BACK":
        # an explicit 'back' without control + revbackground falls back to 'self'
        if doControl and doRevBackground:
            return "back"

        return "self"

    if doRevBackground:
        return "back"

    return "self"
425
426
def openRDSFile(filename, cachePages=None, doCache=False):
    """Open filename as a ReadDataset (verbose) and return it.

    doCache is passed through as the dataset's cache flag.  When cachePages is
    given and exceeds the dataset's default cache size, the cache is enlarged
    to cachePages.  cachePages may be an int or a numeric string (as read from
    a configuration file); a non-numeric string raises ValueError.
    """
    rds = ReadDataset.ReadDataset(filename, verbose=True, cache=doCache)

    # Compare as an integer and only when a value was supplied: comparing None
    # or a string with an int gives a meaningless result under Python 2 (any
    # string sorts above any int) and raises TypeError under Python 3.
    if cachePages is not None:
        requestedPages = int(cachePages)
        if requestedPages > rds.getDefaultCacheSize():
            rds.setDBcache(requestedPages)

    return rds
433
434
def writeOutputFileHeader(regionFinder, outfile, hitfile, useMulti, doCache, pValueType, doPvalue, controlfile, doControl):
    """Write the analysis description and the column header to outfile.

    Both are produced by regionFinder and written one per line; the column
    header line is also returned.
    """
    description = regionFinder.getAnalysisDescription(hitfile, useMulti, doCache, pValueType, controlfile, doControl)
    outfile.write("%s\n" % description)

    columnHeader = regionFinder.getHeader(doPvalue)
    outfile.write("%s\n" % columnHeader)

    return columnHeader
441
442
def getChromosomeListToProcess(hitRDS, controlRDS=None, doControl=False):
    """Return the sample's chromosomes to analyse, in the sample's order.

    "chrM" is always excluded.  When doControl is true, chromosomes that the
    control dataset does not also report are excluded as well.
    """
    candidates = [chrom for chrom in hitRDS.getChromosomes() if chrom != "chrM"]
    if not doControl:
        return candidates

    controlChromList = controlRDS.getChromosomes()
    sharedChromosomes = []
    for chrom in candidates:
        if chrom in controlChromList:
            sharedChromosomes.append(chrom)

    return sharedChromosomes
452
453
def _evaluateCandidateRegion(regionFinder, hitRDS, controlRDS, chromosome, readStartPositions, reads,
                             totalWeight, uniqueReadCount, numStarts, useMulti, doControl,
                             allregions, outregions):
    """Score one closed run of sample reads and record it if it passes.

    Always appends the region's (optionally normalized) read count to
    allregions; appends the region to outregions and updates the "index",
    "total" and "failed" statistics when it gets that far.

    Returns True only when trimming the region raised IndexError (a "bad
    region trim"), False in every other case.
    """
    lastReadPos = readStartPositions[-1]
    lastBasePosition = lastReadPos + regionFinder.readlen - 1
    newRegionIndex = regionFinder.statistics["index"] + 1
    if regionFinder.doDirectionality:
        region = Region.DirectionalRegion(readStartPositions[0], lastBasePosition, chrom=chromosome, index=newRegionIndex, label=regionFinder.regionLabel,
                                          numReads=totalWeight)
    else:
        region = Region.Region(readStartPositions[0], lastBasePosition, chrom=chromosome, index=newRegionIndex, label=regionFinder.regionLabel, numReads=totalWeight)

    if regionFinder.normalize:
        region.numReads /= regionFinder.sampleRDSsize

    allregions.append(int(region.numReads))
    regionLength = lastReadPos - region.start
    if not regionPassesCriteria(regionFinder, region.numReads, numStarts, regionLength):
        return False

    region.foldRatio = getFoldRatio(regionFinder, controlRDS, region.numReads, chromosome, region.start, lastReadPos, useMulti, doControl)
    if not region.foldRatio >= regionFinder.minRatio:
        return False

    # first pass, with absolute numbers
    peak = findPeak(reads, region.start, regionLength, regionFinder.readlen, doWeight=True, leftPlus=regionFinder.doDirectionality, shift=regionFinder.shiftValue)
    if regionFinder.doTrim:
        try:
            lastReadPos = trimRegion(region, regionFinder, peak, lastReadPos, regionFinder.trimValue, reads, regionFinder.sampleRDSsize)
        except IndexError:
            return True

        region.foldRatio = getFoldRatio(regionFinder, controlRDS, region.numReads, chromosome, region.start, lastReadPos, useMulti, doControl)

    # just in case it changed, use latest data
    try:
        bestPos = peak.topPos[0]
        peakScore = peak.smoothArray[bestPos]
    except IndexError:
        # no usable peak position: nothing to report for this region
        return False

    if regionFinder.normalize:
        peakScore /= regionFinder.sampleRDSsize

    if regionFinder.listPeak:
        region.peakDescription = "%d\t%.1f" % (region.start + bestPos, peakScore)

    if useMulti:
        setMultireadPercentage(region, hitRDS, regionFinder.sampleRDSsize, totalWeight, uniqueReadCount, chromosome, lastReadPos,
                               regionFinder.normalize, regionFinder.doTrim)

    region.shift = peak.shift
    # check that we still pass threshold
    regionLength = lastReadPos - region.start
    plusRatio = float(peak.numPlus)/peak.numHits
    if regionAndPeakPass(regionFinder, region, regionLength, peakScore, plusRatio):
        try:
            updateRegion(region, regionFinder.doDirectionality, regionFinder.leftPlusRatio, peak.numLeftPlus, peak.numPlus, plusRatio)
            regionFinder.statistics["index"] += 1
            outregions.append(region)
            regionFinder.statistics["total"] += region.numReads
        except RegionDirectionError:
            regionFinder.statistics["failed"] += 1

    return False


def findPeakRegions(regionFinder, hitRDS, chromosome, logfilename, outfilename,
                    outfile, useMulti, doControl, controlRDS, combine5p):
    """Scan one chromosome of the sample and collect candidate peak regions.

    Reads are accumulated until previousRegionIsDone() reports that the
    current run has ended; the run is then evaluated and the accumulators are
    reset before the read that ended it starts the next run.

    logfilename, outfilename and outfile are accepted but not used here.

    Returns (allregions, outregions): the integer read count of every run
    seen, and the Region objects that passed all filters.
    """
    outregions = []
    allregions = []
    print("chromosome %s" % (chromosome))
    previousHit = - 1 * regionFinder.maxSpacing
    readStartPositions = [-1]
    totalWeight = 0
    uniqueReadCount = 0
    reads = []
    numStarts = 0
    hitDict = hitRDS.getReadsDict(fullChrom=True, chrom=chromosome, flag=regionFinder.withFlag, withWeight=True, doMulti=useMulti, findallOptimize=True,
                                  strand=regionFinder.stranded, combine5p=combine5p)

    maxCoord = hitRDS.getMaxCoordinate(chromosome, doMulti=useMulti)
    for read in hitDict[chromosome]:
        pos = read["start"]
        if previousRegionIsDone(pos, previousHit, regionFinder.maxSpacing, maxCoord):
            # Evaluation happens in a helper so that every early exit still
            # reaches the reset below.  The earlier in-loop `continue` on a
            # failed trim or missing peak skipped both the reset and the
            # bookkeeping for the current read, leaving stale state behind,
            # dropping the read and never counting the bad trim.
            trimFailed = _evaluateCandidateRegion(regionFinder, hitRDS, controlRDS, chromosome, readStartPositions, reads,
                                                  totalWeight, uniqueReadCount, numStarts, useMulti, doControl,
                                                  allregions, outregions)

            readStartPositions = []
            totalWeight = 0
            uniqueReadCount = 0
            reads = []
            numStarts = 0
            if trimFailed:
                regionFinder.statistics["badRegionTrim"] += 1

        if pos not in readStartPositions:
            numStarts += 1

        readStartPositions.append(pos)
        weight = read["weight"]
        totalWeight += weight
        if weight == 1.0:
            uniqueReadCount += 1

        reads.append({"start": pos, "sense": read["sense"], "weight": weight})
        previousHit = pos

    return allregions, outregions
552
553     return allregions, outregions
554
555
def _evaluateBackgroundRegion(regionFinder, hitRDS, chromosome, currentHitList, currentReadList,
                              currentTotalWeight, numStarts, useMulti, backregions):
    """Score one closed run of control reads against the sample.

    Always appends the run's (optionally normalized) read count to
    backregions, and calls regionFinder.updateControlStatistics() when the run
    passes the region, fold-ratio and peak checks.

    Returns True only when trimming the region raised IndexError, False in
    every other case.
    """
    lastReadPos = currentHitList[-1]
    lastBasePosition = lastReadPos + regionFinder.readlen - 1
    region = Region.Region(currentHitList[0], lastBasePosition, chrom=chromosome, label=regionFinder.regionLabel, numReads=currentTotalWeight)
    if regionFinder.normalize:
        region.numReads /= regionFinder.controlRDSsize

    backregions.append(int(region.numReads))
    # NOTE(review): the region is rebuilt from the raw weight here, so the
    # checks below see the un-normalized read count - confirm this is intended.
    region = Region.Region(currentHitList[0], lastBasePosition, chrom=chromosome, label=regionFinder.regionLabel, numReads=currentTotalWeight)
    regionLength = lastReadPos - region.start
    if not regionPassesCriteria(regionFinder, region.numReads, numStarts, regionLength):
        return False

    numMock = 1. + hitRDS.getCounts(chromosome, region.start, lastReadPos, uniqs=True, multi=useMulti, splices=False, reportCombined=True)
    if regionFinder.normalize:
        numMock /= regionFinder.sampleRDSsize

    foldRatio = region.numReads / numMock
    if not foldRatio >= regionFinder.minRatio:
        return False

    # first pass, with absolute numbers
    peak = findPeak(currentReadList, region.start, lastReadPos - region.start, regionFinder.readlen, doWeight=True,
                    leftPlus=regionFinder.doDirectionality, shift=regionFinder.shiftValue)

    if regionFinder.doTrim:
        try:
            # NOTE(review): the trim value is the literal 20. here rather than
            # regionFinder.trimValue - confirm this is intended.
            lastReadPos = trimRegion(region, regionFinder, peak, lastReadPos, 20., currentReadList, regionFinder.controlRDSsize)
        except IndexError:
            return True

        numMock = 1. + hitRDS.getCounts(chromosome, region.start, lastReadPos, uniqs=True, multi=useMulti, splices=False, reportCombined=True)
        if regionFinder.normalize:
            numMock /= regionFinder.sampleRDSsize

        foldRatio = region.numReads / numMock

    # just in case it changed, use latest data
    try:
        bestPos = peak.topPos[0]
        peakScore = peak.smoothArray[bestPos]
    except IndexError:
        return False

    # normalize to RPM
    if regionFinder.normalize:
        peakScore /= regionFinder.controlRDSsize

    # check that we still pass threshold
    regionLength = lastReadPos - region.start
    # NOTE(review): foldRatio is passed in the position that receives
    # numStarts in the first check above - confirm against regionPassesCriteria.
    if regionPassesCriteria(regionFinder, region.numReads, foldRatio, regionLength):
        regionFinder.updateControlStatistics(peak, region.numReads, peakScore)

    return False


def findBackgroundRegions(regionFinder, hitRDS, controlRDS, chromosome, useMulti):
    """Scan one chromosome of the control dataset, treating it as the sample.

    Control reads are grouped into runs exactly as the sample reads are; each
    closed run is scored against the sample counts and feeds the control
    statistics on regionFinder.

    Returns backregions, the integer read count of every run seen.
    """
    #TODO: this is *almost* the same calculation - there are small yet important differences
    print("calculating background...")
    previousHit = - 1 * regionFinder.maxSpacing
    currentHitList = [-1]
    currentTotalWeight = 0
    currentReadList = []
    backregions = []
    numStarts = 0
    hitDict = controlRDS.getReadsDict(fullChrom=True, chrom=chromosome, withWeight=True, doMulti=useMulti, findallOptimize=True)
    maxCoord = controlRDS.getMaxCoordinate(chromosome, doMulti=useMulti)
    for read in hitDict[chromosome]:
        pos = read["start"]
        if previousRegionIsDone(pos, previousHit, regionFinder.maxSpacing, maxCoord):
            # Evaluation happens in a helper so that every early exit still
            # reaches the reset below.  The earlier in-loop `continue` on a
            # failed trim or missing peak skipped both the reset and the
            # bookkeeping for the current read, leaving stale state behind,
            # dropping the read and never counting the bad trim.
            trimFailed = _evaluateBackgroundRegion(regionFinder, hitRDS, chromosome, currentHitList, currentReadList,
                                                   currentTotalWeight, numStarts, useMulti, backregions)

            currentHitList = []
            currentTotalWeight = 0
            currentReadList = []
            numStarts = 0
            if trimFailed:
                regionFinder.statistics["badRegionTrim"] += 1

        if pos not in currentHitList:
            numStarts += 1

        currentHitList.append(pos)
        weight = read["weight"]
        currentTotalWeight += weight
        currentReadList.append({"start": pos, "sense": read["sense"], "weight": weight})
        previousHit = pos

    return backregions
636
637     return backregions
638
639
def learnShift(regionFinder, hitRDS, chromosome, logfilename, outfilename,
               outfile, useMulti, doControl, controlRDS, combine5p):
    """ Learn the read shift for one chromosome and store it in regionFinder.shiftValue.

        Reads from hitRDS are grouped into candidate regions; a region is closed
        whenever previousRegionIsDone() reports so for the next read.  A closed
        region becomes a training example when it passes regionPassesCriteria()
        at regionFinder.stringency and its fold ratio from getFoldRatio() is at
        least regionFinder.minRatio.  The per-region shifts are tallied with
        updateShiftDict() and the final value is chosen by getShiftValue().

        A summary is printed and passed to writeLog(); the picked shift is also
        written to outfile.  Nothing is returned.
    """
    hitDict = hitRDS.getReadsDict(fullChrom=True, chrom=chromosome, flag=regionFinder.withFlag, withWeight=True,
                                  doMulti=useMulti, findallOptimize=True, strand=regionFinder.stranded,
                                  combine5p=combine5p)

    maxCoord = hitRDS.getMaxCoordinate(chromosome, doMulti=useMulti)
    print("learning shift.... will need at least 30 training sites")
    stringency = regionFinder.stringency
    previousHit = -1 * regionFinder.maxSpacing
    positionList = [-1]
    # mirrors positionList for constant-time "have we seen this start" checks
    seenPositions = set(positionList)
    totalWeight = 0
    readList = []
    shiftDict = {}
    count = 0
    numStarts = 0
    for read in hitDict[chromosome]:
        pos = read["start"]
        if previousRegionIsDone(pos, previousHit, regionFinder.maxSpacing, maxCoord):
            if regionFinder.normalize:
                totalWeight /= regionFinder.sampleRDSsize

            regionStart = positionList[0]
            regionStop = positionList[-1]
            regionLength = regionStop - regionStart
            if regionPassesCriteria(regionFinder, totalWeight, numStarts, regionLength, stringency=stringency):
                foldRatio = getFoldRatio(regionFinder, controlRDS, totalWeight, chromosome, regionStart, regionStop, useMulti, doControl)
                if foldRatio >= regionFinder.minRatio:
                    updateShiftDict(shiftDict, readList, regionStart, regionLength, regionFinder.readlen)
                    count += 1

            # start a fresh region; the distinct-start counter has to be reset
            # with the rest of the per-region state or it keeps growing and
            # every later region trivially passes the read-start criterion
            positionList = []
            seenPositions = set()
            totalWeight = 0
            readList = []
            numStarts = 0

        if pos not in seenPositions:
            seenPositions.add(pos)
            numStarts += 1

        positionList.append(pos)
        weight = read["weight"]
        totalWeight += weight
        readList.append({"start": pos, "sense": read["sense"], "weight": weight})
        previousHit = pos

    outline = "#learn: stringency=%.2f min_signal=%.2f min_ratio=%.2f min_region_size=%d\n#number of training examples: %d" % (stringency,
                                                                                                                               stringency * regionFinder.minHits,
                                                                                                                               stringency * regionFinder.minRatio,
                                                                                                                               stringency * regionFinder.readlen,
                                                                                                                               count)

    print(outline)
    writeLog(logfilename, versionString, outfilename + outline)
    regionFinder.shiftValue = getShiftValue(shiftDict, count, logfilename, outfilename)
    outline = "#picked shiftValue to be %d" % regionFinder.shiftValue
    print(outline)
    outfile.write(outline + "\n")
    writeLog(logfilename, versionString, outfilename + outline)
697
698
def previousRegionIsDone(pos, previousHit, maxSpacing, maxCoord):
    """ Tell whether the read at pos closes the region that was being built.

        True when pos is more than maxSpacing away from previousHit, or when
        pos equals maxCoord.
    """
    gapFromPrevious = abs(pos - previousHit)
    atMaxCoord = pos == maxCoord

    return gapFromPrevious > maxSpacing or atMaxCoord
701
702
def regionPassesCriteria(regionFinder, sumAll, numStarts, regionLength, stringency=1):
    """ Check a region against the regionFinder thresholds scaled by stringency.

        All three must hold:
          sumAll       >= stringency * regionFinder.minHits
          numStarts    >  stringency * regionFinder.minRatio
          regionLength >  stringency * regionFinder.readlen

        Note that the second argument is compared against minRatio; callers in
        this file pass either a count of distinct read starts or a fold ratio.
    """
    checks = (
        sumAll >= stringency * regionFinder.minHits,
        numStarts > stringency * regionFinder.minRatio,
        regionLength > stringency * regionFinder.readlen,
    )

    return all(checks)
709
710
def trimRegion(region, regionFinder, peak, regionStop, trimValue, currentReadList, totalReadCount):
    """ Trim a region to the span where the smoothed signal reaches the trim threshold.

        The threshold is trimValue times peak.smoothArray at peak.topPos[0]
        (divided by totalReadCount when regionFinder.normalize is set).  The
        edges come from findStartEdgePosition() and findStopEdgePosition().

        Updated in place: region.start, region.stop and region.numReads, plus
        peak.numPlus, peak.numLeftPlus, peak.topPos and peak.smoothArray, which
        take the values findPeak() reports for the trimmed span.

        Returns the trimmed stop position (region.stop additionally includes
        readlen - 1).  Index errors raised while scanning peak.smoothArray are
        not caught here.
    """
    normalize = regionFinder.normalize
    readlen = regionFinder.readlen

    topScore = peak.smoothArray[peak.topPos[0]]
    if normalize:
        topScore /= totalReadCount

    threshold = trimValue * topScore
    leftOffset = findStartEdgePosition(peak, threshold)
    rightOffset = findStopEdgePosition(peak, regionStop - region.start - 1, threshold)

    # both offsets are relative to the untrimmed start, so the new stop has to
    # be computed before region.start is moved
    trimmedStop = region.start + rightOffset
    region.start += leftOffset

    trimmedPeak = findPeak(currentReadList, region.start, trimmedStop - region.start, readlen, doWeight=True,
                           leftPlus=regionFinder.doDirectionality, shift=peak.shift)

    for attribute in ("numPlus", "numLeftPlus", "topPos", "smoothArray"):
        setattr(peak, attribute, getattr(trimmedPeak, attribute))

    numReads = trimmedPeak.numHits
    if normalize:
        numReads /= totalReadCount

    region.numReads = numReads
    region.stop = trimmedStop + readlen - 1

    return trimmedStop
740
741
def findStartEdgePosition(peak, minSignalThresh):
    """ Scan rightwards from index 0 and return the first index for which
        peakEdgeLocated() is true.
    """
    position = 0
    while True:
        if peakEdgeLocated(peak, position, minSignalThresh):
            return position

        position += 1
748
749
def findStopEdgePosition(peak, stop, minSignalThresh):
    """ Scan leftwards from index stop and return the first index for which
        peakEdgeLocated() is true.
    """
    position = stop
    while True:
        if peakEdgeLocated(peak, position, minSignalThresh):
            return position

        position -= 1
755
756
def peakEdgeLocated(peak, position, minSignalThresh):
    """ True when the smoothed signal at position reaches minSignalThresh, or
        when position is the peak's top position (peak.topPos[0]).
    """
    # the array lookup comes first on purpose: an out-of-range position must
    # still raise, even when it would equal the top position
    signalReached = peak.smoothArray[position] >= minSignalThresh

    return signalReached or position == peak.topPos[0]
759
760
def getFoldRatio(regionFinder, controlRDS, sumAll, chromosome, regionStart, regionStop, useMulti, doControl):
    """ Return the fold ratio of sumAll over the control reads in the region.

        Without a control (doControl false) regionFinder.minRatio is returned
        unchanged.  Otherwise the control count is 1 plus controlRDS.getCounts()
        for the region, divided by regionFinder.controlRDSsize when
        regionFinder.normalize is set.
    """
    #TODO: this needs to be generalized as there is a point at which we want to use the sampleRDS instead of controlRDS
    if not doControl:
        return regionFinder.minRatio

    controlCount = controlRDS.getCounts(chromosome, regionStart, regionStop, uniqs=True, multi=useMulti,
                                        splices=False, reportCombined=True)
    # the added 1 keeps the denominator non-zero when the control has no reads
    numMock = 1. + controlCount
    if regionFinder.normalize:
        numMock /= regionFinder.controlRDSsize

    return sumAll / numMock
775
776
def updateShiftDict(shiftDict, readList, regionStart, regionLength, readlen):
    """ Tally the shift that findPeak() picks for this region (shift="auto")
        into shiftDict, which maps shift value -> number of regions.
    """
    regionShift = findPeak(readList, regionStart, regionLength, readlen, doWeight=True, shift="auto").shift
    shiftDict[regionShift] = shiftDict.get(regionShift, 0) + 1
783
784
def getShiftValue(shiftDict, count, logfilename, outfilename):
    """ Pick the shift value from the tallied training examples.

        With fewer than 30 examples a warning is printed and logged and 0 is
        returned; otherwise the most frequent shift in shiftDict is returned
        and shiftDict is printed.
    """
    if count < 30:
        warning = "#too few training examples to pick a shiftValue - defaulting to 0\n#consider picking a lower minimum or threshold"
        print(warning)
        writeLog(logfilename, versionString, outfilename + warning)
        return 0

    bestShift = getBestShiftInDict(shiftDict)
    print(shiftDict)

    return bestShift
796
797
def getRegion(regionStart, regionStop, factor, index, chromosome, sumAll, foldRatio, multiP,
              peakDescription, shift, doDirectionality, leftPlusRatio, numLeft,
              numPlus, plusRatio):
    """ Build the Region object for a region that passed the earlier criteria.

        Without directionality checking a plain Region.Region is returned.
        With it, numLeft / numPlus must exceed leftPlusRatio; the result is then
        a Region.DirectionalRegion carrying the plus-strand and left-plus
        percentages, otherwise RegionDirectionError is raised.
    """
    if not doDirectionality:
        # we have a region, but didn't check for directionality
        return Region.Region(regionStart, regionStop, factor, index, chromosome,
                             sumAll, foldRatio, multiP, peakDescription, shift)

    passesLeftPlus = leftPlusRatio < numLeft / numPlus
    if not passesLeftPlus:
        raise RegionDirectionError

    # we have a region that passes all criteria
    plusP = plusRatio * 100.
    leftP = 100. * numLeft / numPlus

    return Region.DirectionalRegion(regionStart, regionStop,
                                    factor, index, chromosome, sumAll,
                                    foldRatio, multiP, plusP, leftP,
                                    peakDescription, shift)
820
821
def setMultireadPercentage(region, hitRDS, hitRDSsize, currentTotalWeight, currentUniqueCount, chromosome, lastReadPos, normalize, doTrim):
    """ Set region.multiP to the percentage of the region's reads that are multireads.

        The multiread amount is hitRDS.getMultiCount() over region.start to
        lastReadPos when doTrim is set, otherwise currentTotalWeight minus
        currentUniqueCount; it is divided by hitRDSsize when normalize is set.
        The percentage is capped at 100.  If dividing by region.numReads raises
        ZeroDivisionError, region.multiP is left untouched.
    """
    multiWeight = (hitRDS.getMultiCount(chromosome, region.start, lastReadPos) if doTrim
                   else currentTotalWeight - currentUniqueCount)

    # normalize to RPM
    if normalize:
        multiWeight /= hitRDSsize

    try:
        multiFraction = multiWeight / region.numReads
    except ZeroDivisionError:
        return

    region.multiP = min(100. * multiFraction, 100.)
838
839
def regionAndPeakPass(regionFinder, region, regionLength, peakScore, plusRatio):
    """ True when the region passes regionPassesCriteria() (called with
        region.numReads and region.foldRatio), peakScore is at least
        regionFinder.minPeak, and plusRatio lies within
        [regionFinder.minPlusRatio, regionFinder.maxPlusRatio].
    """
    if not regionPassesCriteria(regionFinder, region.numReads, region.foldRatio, regionLength):
        return False

    peakIsHighEnough = peakScore >= regionFinder.minPeak
    plusRatioInRange = regionFinder.minPlusRatio <= plusRatio <= regionFinder.maxPlusRatio

    return bool(peakIsHighEnough and plusRatioInRange)
847
848
def updateRegion(region, doDirectionality, leftPlusRatio, numLeft, numPlus, plusRatio):
    """ Refresh the directionality percentages stored on region.

        Does nothing unless doDirectionality is set.  Otherwise numLeft / numPlus
        must exceed leftPlusRatio: region.plusP and region.leftP are then
        updated, else RegionDirectionError is raised.
    """
    if not doDirectionality:
        return

    passesLeftPlus = leftPlusRatio < numLeft / numPlus
    if not passesLeftPlus:
        raise RegionDirectionError

    region.plusP = plusRatio * 100.
    region.leftP = 100. * numLeft / numPlus
857
858
def writeNoRevBackgroundResults(regionFinder, outregions, outfile, doPvalue, shiftDict,
                                allregions, header):
    """ Write a chromosome's results when there are no background regions.

        Delegates to writeChromosomeResults() with an empty background list and
        pValueType "self".
    """
    noBackgroundRegions = []
    writeChromosomeResults(regionFinder, outregions, outfile, doPvalue, shiftDict, allregions, header,
                           backregions=noBackgroundRegions, pValueType="self")
864
865
def writeChromosomeResults(regionFinder, outregions, outfile, doPvalue, shiftDict,
                           allregions, header, backregions=None, pValueType="none"):
    """ Print the header and write the regions found on one chromosome.

        The background statistics regionFinder.statistics["mIndex"] and
        ["mTotal"] are printed first.  When doPvalue is set, the Poisson mean
        is calculated from allregions (pValueType "self") or from backregions
        (any other pValueType) and handed to writeRegions().

        backregions defaults to an empty list.
    """
    if backregions is None:
        backregions = []

    print("%s %s" % (regionFinder.statistics["mIndex"], regionFinder.statistics["mTotal"]))

    # has to be bound even when no p-value is requested, because it is always
    # passed on to writeRegions() (which ignores it in that case)
    poissonmean = 0
    if doPvalue:
        if pValueType == "self":
            poissonmean = calculatePoissonMean(allregions)
        else:
            poissonmean = calculatePoissonMean(backregions)

    print(header)
    writeRegions(outregions, outfile, doPvalue, poissonmean, shiftValue=regionFinder.shiftValue, reportshift=regionFinder.reportshift, shiftDict=shiftDict)
878
879
def calculatePoissonMean(dataList):
    """ Return the arithmetic mean of dataList, or 0 for an empty list.

        dataList is sorted in place.  The list size and the mean are printed.
    """
    dataList.sort()
    listSize = float(len(dataList))
    if listSize:
        poissonmean = sum(dataList) / listSize
    else:
        # nothing to average
        poissonmean = 0

    print("Poisson n=%d, p=%f" % (listSize, poissonmean))

    return poissonmean
891
892
def writeRegions(outregions, outfile, doPvalue, poissonmean, shiftValue=0, reportshift=False, shiftDict=None):
    """ Print each region and write the same line to outfile.

        When shiftValue is "auto" and reportshift is set, each region's shift is
        tallied into shiftDict (shift value -> count).  When doPvalue is set, a
        tab and the value from calculatePValue() for int(region.numReads) and
        poissonmean are appended to the region's line.

        shiftDict defaults to a fresh dict per call, so tallies from one call
        never leak into the next; pass a dict to collect them.
    """
    if shiftDict is None:
        shiftDict = {}

    tallyShifts = shiftValue == "auto" and reportshift
    for region in outregions:
        if tallyShifts:
            shiftDict[region.shift] = shiftDict.get(region.shift, 0) + 1

        outline = getRegionString(region, reportshift)

        if doPvalue:
            sumAll = int(region.numReads)
            pValue = calculatePValue(sumAll, poissonmean)
            outline += "\t%1.2g" % pValue

        print(outline)
        outfile.write(outline + "\n")
911
912
def calculatePValue(sum, poissonmean):
    """ Return the Poisson probability of exactly `sum` events for the given mean.

        Evaluated iteratively as exp(-mean) * mean**k / k!, one multiplication
        and one division per step (iterative poisson from
        http://stackoverflow.com/questions/280797?sort=newest).

        sum may be given as a float; it is truncated with int() before use.
        The parameter name shadows the builtin but is kept for callers.
    """
    eventCount = int(sum)
    pValue = math.exp(-poissonmean)
    for divisor in range(1, eventCount + 1):
        pValue *= poissonmean
        pValue /= divisor

    return pValue
921
922
def getRegionString(region, reportShift):
    """ Return the region's output line: region.printRegionWithShift() when
        reportShift is set, region.printRegion() otherwise.
    """
    formatter = region.printRegionWithShift if reportShift else region.printRegion

    return formatter()
930
931
def getFooter(regionFinder, shiftDict, doRevBackground):
    """ Build the multi-line statistics footer from regionFinder.statistics.

        Always reports the total and number of regions.  Further lines are added
        for: regions that failed the directionality filter (when
        regionFinder.doDirectionality), the background regions with the
        percentage mIndex / index capped at 100 (when doRevBackground), the most
        frequent shift in shiftDict (when regionFinder.shiftValue is "auto",
        regionFinder.reportshift is set and shiftDict is not empty), and regions
        discarded by trimming (when that count is above 0).

        Returns the lines joined with newlines.
    """
    statistics = regionFinder.statistics
    index = statistics["index"]
    mIndex = statistics["mIndex"]
    footerLines = ["#stats:\t%.1f RPM in %d regions" % (statistics["total"], index)]
    if regionFinder.doDirectionality:
        footerLines.append("#\t\t%d additional regions failed directionality filter" % statistics["failed"])

    if doRevBackground:
        try:
            percent = min(100. * (float(mIndex)/index), 100.)
        except ZeroDivisionError:
            percent = 0.

        footerLines.append("#%d regions (%.1f RPM) found in background (FDR = %.2f percent)" % (mIndex, statistics["mTotal"], percent))

    # an empty shiftDict has no mode; skip the line instead of failing in max()
    if regionFinder.shiftValue == "auto" and regionFinder.reportshift and shiftDict:
        bestShift = getBestShiftInDict(shiftDict)
        footerLines.append("#mode of shift values: %d" % bestShift)

    if statistics["badRegionTrim"] > 0:
        footerLines.append("#%d regions discarded due to trimming problems" % statistics["badRegionTrim"])

    return "\n".join(footerLines)
955
956
def getBestShiftInDict(shiftDict):
    """ Return the key of shiftDict (shift value -> count) with the highest count.
    """
    return max(shiftDict, key=shiftDict.get)
959
960
# Script entry point: hand the raw command line to main() when this file is
# executed directly; nothing runs when it is imported as a module.
if __name__ == "__main__":
    main(sys.argv)