663d7e075d5d30535da6eed6547d4da2b4970be2
[erange.git] / findall.py
1 """
2     usage: python $ERANGEPATH/findall.py label samplerdsfile regionoutfile
3            [--control controlrdsfile] [--minimum minHits] [--ratio minRatio]
4            [--spacing maxSpacing] [--listPeak] [--shift #bp | learn] [--learnFold num]
5            [--noshift] [--autoshift] [--reportshift] [--nomulti] [--minPlus fraction]
6            [--maxPlus fraction] [--leftPlus fraction] [--minPeak RPM] [--raw]
7            [--revbackground] [--pvalue self|back|none] [--nodirectionality]
8            [--strandfilter plus/minus] [--trimvalue percent] [--notrim]
9            [--cache pages] [--log altlogfile] [--flag aflag] [--append] [--RNA]
10
11            where values in brackets are optional and label is an arbitrary string.
12
13            Use -ratio (default 4 fold) to set the minimum fold enrichment
14            over the control, -minimum (default 4) is the minimum number of reads
15            (RPM) within the region, and -spacing (default readlen) to set the maximum
16            distance between reads in the region. -listPeak lists the peak of the
17            region. Peaks must be higher than -minPeak (default 0.5 RPM).
18            Pvalues are calculated from the sample (change with -pvalue),
19            unless the -revbackground flag and a control RDS file are provided.
20
21            By default, all numbers and parameters are on a reads per
22            million (RPM) basis. -raw will treat all settings, ratios and reported
23            numbers as raw counts rather than RPM. Use -notrim to turn off region
24            trimming and -trimvalue to control trimming (default 10% of peak signal)
25
26            The peak finder uses minimal directionality information that can
27            be turned off with -nodirectionality ; the fraction of + strand reads
28            required to be to the left of the peak (default 0.3) can be set with
29            -leftPlus ; -minPlus and -maxPlus change the minimum/maximum fraction
30            of plus reads in a region (defaults 0.25 and 0.75, respectively).
31
32            Use -shift to shift reads either by half the expected
33            fragment length (default 0 bp) or '-shift learn ' to learn the shift
34            based on the first chromosome. If you prefer to learn the shift
35            manually, use -autoshift to calculate a per-region shift value, which
36            can be reported using -reportshift. -strandfilter should only be used
37            when explicitly calling unshifted stranded peaks from non-ChIP-seq
38            data such as directional RNA-seq. regionoutfile is written over by
39            default unless given the -append flag.
40 """
41
42 try:
43     import psyco
44     psyco.full()
45 except:
46     pass
47
48 import sys
49 import math
50 import string
51 import optparse
52 import operator
53 from commoncode import writeLog, findPeak, getBestShiftForRegion, getConfigParser, getConfigOption, getConfigIntOption, getConfigFloatOption, getConfigBoolOption
54 import ReadDataset
55 import Region
56
57
# Version banner; also embedded in the output header and passed to writeLog().
versionString = "findall: version 3.2.1"
# Single-argument print(...) behaves the same as the Python 2 print statement
# and additionally parses under Python 3.
print(versionString)
60
class RegionDirectionError(Exception):
    """Signal that a region was rejected while being updated.

    findPeakRegions() catches this around its updateRegion() call and counts
    the region under the "failed" statistic instead of reporting it.
    """
63
64
class RegionFinder(object):
    """Thresholds, switches and running statistics for one findall run.

    The module-level functions read and update these attributes directly.
    ``readlen`` is not set by the constructor; findall() assigns it after
    opening the sample dataset.
    """

    def __init__(self, minRatio, minPeak, minPlusRatio, maxPlusRatio, leftPlusRatio, strandfilter, minHits, trimValue, doTrim, doDirectionality,
                 shiftValue, maxSpacing):
        """Store the run settings.

        A strandfilter of "plus" or "minus" restricts analysis to that strand
        and overrides minPlusRatio/maxPlusRatio with (0.9, 1.0) or (0.0, 0.1).
        minPeak is lowered to minRatio when minRatio is the smaller of the two.
        """
        self.statistics = {"index": 0,
                           "total": 0,
                           "mIndex": 0,
                           "mTotal": 0,
                           "failed": 0,
                           "badRegionTrim": 0
        }

        # Placeholder dataset sizes; findall() replaces them with the read
        # counts expressed in millions of reads.
        self.controlRDSsize = 1
        self.sampleRDSsize = 1
        self.minRatio = minRatio
        self.minPeak = minPeak
        self.leftPlusRatio = leftPlusRatio
        self.stranded = "both"
        if strandfilter == "plus":
            self.stranded = "+"
            minPlusRatio = 0.9
            maxPlusRatio = 1.0
        elif strandfilter == "minus":
            self.stranded = "-"
            minPlusRatio = 0.0
            maxPlusRatio = 0.1

        if minRatio < minPeak:
            self.minPeak = minRatio

        self.minPlusRatio = minPlusRatio
        self.maxPlusRatio = maxPlusRatio
        self.strandfilter = strandfilter
        self.minHits = minHits
        self.trimValue = trimValue
        self.doTrim = doTrim
        self.doDirectionality = doDirectionality

        # Human-readable trimming description used in the summaries.
        if self.doTrim:
            self.trimString = "%2.1f%%" % (100. * self.trimValue)
        else:
            self.trimString = "none"

        self.shiftValue = shiftValue
        self.maxSpacing = maxSpacing


    def useRNASettings(self):
        """Switch to the RNA preset: no shift, no trimming, no directionality."""
        self.shiftValue = 0
        self.doTrim = False
        # Keep the reported trimming in sync with doTrim; previously the
        # summaries kept announcing the percentage computed in __init__.
        self.trimString = "none"
        self.doDirectionality = False


    def _formatStrandSummary(self, pValueType):
        """Return the minPlus/maxPlus/leftPlus/shift/pvalue summary text.

        shiftValue is either a number or a keyword such as "learn" or "auto";
        "%d" raises TypeError for the latter, in which case it is shown as-is.
        """
        try:
            shiftText = "%d" % self.shiftValue
        except TypeError:
            shiftText = "%s" % self.shiftValue

        return "minPlus=%.2f maxPlus=%.2f leftPlus=%.2f shift=%s pvalue=%s" % (self.minPlusRatio, self.maxPlusRatio, self.leftPlusRatio, shiftText, pValueType)


    def printOptionsSummary(self, listPeak, useMulti, doCache, pValueType):
        """Print the effective settings of this run to standard output."""
        print("\nenforceDirectionality=%s listPeak=%s nomulti=%s cache=%s " % (self.doDirectionality, listPeak, not useMulti, doCache))
        print("spacing<%d minimum>%.1f ratio>%.1f minPeak=%.1f\ttrimmed=%s\tstrand=%s" % (self.maxSpacing, self.minHits, self.minRatio, self.minPeak, self.trimString, self.stranded))
        print(self._formatStrandSummary(pValueType))


    def getAnalysisDescription(self, hitfile, controlfile, doControl,
                                 withFlag, listPeak, useMulti, doCache, pValueType):
        """Return the multi-line "#"-prefixed header describing the analysis.

        The text names the sample (and control, when doControl is true) with
        their sizes in millions of reads, the optional flag restriction and
        the same settings that printOptionsSummary() prints.
        """
        description = ["#ERANGE %s" % versionString]
        if doControl:
            description.append("#enriched sample:\t%s (%.1f M reads)\n#control sample:\t%s (%.1f M reads)" % (hitfile, self.sampleRDSsize, controlfile, self.controlRDSsize))
        else:
            description.append("#enriched sample:\t%s (%.1f M reads)\n#control sample: none" % (hitfile, self.sampleRDSsize))

        if withFlag != "":
            description.append("#restrict to Flag = %s" % withFlag)

        description.append("#enforceDirectionality=%s listPeak=%s nomulti=%s cache=%s" % (self.doDirectionality, listPeak, not useMulti, doCache))
        description.append("#spacing<%d minimum>%.1f ratio>%.1f minPeak=%.1f trimmed=%s strand=%s" % (self.maxSpacing, self.minHits, self.minRatio, self.minPeak, self.trimString, self.stranded))
        description.append("#" + self._formatStrandSummary(pValueType))

        return "\n".join(description)
148
149
def usage():
    """Print the module docstring, which doubles as the command-line help."""
    # Single-argument print(...) is valid on both Python 2 and Python 3.
    print(__doc__)
152
153
def main(argv=None):
    """Command-line entry point: parse argv and hand everything to findall().

    Expects three positional arguments (label, sample RDS file, output file);
    with fewer, the usage text is printed and the process exits with status 2.
    """
    if not argv:
        argv = sys.argv

    options, args = makeParser().parse_args(argv[1:])
    if len(args) < 3:
        usage()
        sys.exit(2)

    factor, hitfile, outfilename = args[0], args[1], args[2]

    # Precedence, lowest to highest: default 0, --autoshift, --shift, noshift.
    shiftValue = 0
    if options.autoshift:
        shiftValue = "auto"

    requestedShift = options.shift
    if requestedShift is not None:
        if requestedShift == "learn":
            shiftValue = "learn"
        else:
            try:
                shiftValue = int(requestedShift)
            except ValueError:
                # Anything that is neither an integer nor "learn" is ignored.
                pass

    if options.noshift:
        shiftValue = 0

    outputMode = "w"
    if options.doAppend:
        outputMode = "a"

    findall(factor, hitfile, outfilename, options.minHits, options.minRatio, options.maxSpacing, options.listPeak, shiftValue,
            options.stringency, options.reportshift,
            options.minPlusRatio, options.maxPlusRatio, options.leftPlusRatio, options.minPeak,
            options.normalize, options.logfilename, options.withFlag, options.doDirectionality,
            options.trimValue, options.doTrim, outputMode, options.rnaSettings,
            options.cachePages, options.ptype, options.controlfile, options.doRevBackground, options.useMulti,
            options.strandfilter, options.combine5p)
196
197
def makeParser():
    """Build the optparse parser for findall.

    Option defaults are looked up in the "findall" section of the configuration
    returned by getConfigParser(), with the built-in fallbacks given below.
    """
    usage = __doc__

    parser = optparse.OptionParser(usage=usage)
    parser.add_option("--control", dest="controlfile")
    parser.add_option("--minimum", type="float", dest="minHits")
    parser.add_option("--ratio", type="float", dest="minRatio")
    parser.add_option("--spacing", type="int", dest="maxSpacing")
    parser.add_option("--listPeak", action="store_true", dest="listPeak")
    parser.add_option("--shift", dest="shift")
    parser.add_option("--learnFold", type="float", dest="stringency")
    # dest must be "noshift": that is the attribute main() reads and the name
    # the default is registered under. The former dest="noShift" stored the
    # flag where nothing looked at it, so --noshift had no effect.
    parser.add_option("--noshift", action="store_true", dest="noshift")
    parser.add_option("--autoshift", action="store_true", dest="autoshift")
    parser.add_option("--reportshift", action="store_true", dest="reportshift")
    parser.add_option("--nomulti", action="store_false", dest="useMulti")
    parser.add_option("--minPlus", type="float", dest="minPlusRatio")
    parser.add_option("--maxPlus", type="float", dest="maxPlusRatio")
    parser.add_option("--leftPlus", type="float", dest="leftPlusRatio")
    parser.add_option("--minPeak", type="float", dest="minPeak")
    parser.add_option("--raw", action="store_false", dest="normalize")
    parser.add_option("--revbackground", action="store_true", dest="doRevBackground")
    parser.add_option("--pvalue", dest="ptype")
    parser.add_option("--nodirectionality", action="store_false", dest="doDirectionality")
    parser.add_option("--strandfilter", dest="strandfilter")
    parser.add_option("--trimvalue", type="float", dest="trimValue")
    parser.add_option("--notrim", action="store_false", dest="doTrim")
    parser.add_option("--cache", type="int", dest="cachePages")
    parser.add_option("--log", dest="logfilename")
    parser.add_option("--flag", dest="withFlag")
    parser.add_option("--append", action="store_true", dest="doAppend")
    parser.add_option("--RNA", action="store_true", dest="rnaSettings")
    parser.add_option("--combine5p", action="store_true", dest="combine5p")

    configParser = getConfigParser()
    section = "findall"
    minHits = getConfigFloatOption(configParser, section, "minHits", 4.0)
    minRatio = getConfigFloatOption(configParser, section, "minRatio", 4.0)
    maxSpacing = getConfigIntOption(configParser, section, "maxSpacing", 50)
    listPeak = getConfigBoolOption(configParser, section, "listPeak", False)
    shift = getConfigOption(configParser, section, "shift", None)
    stringency = getConfigFloatOption(configParser, section, "stringency", 4.0)
    noshift = getConfigBoolOption(configParser, section, "noshift", False)
    autoshift = getConfigBoolOption(configParser, section, "autoshift", False)
    reportshift = getConfigBoolOption(configParser, section, "reportshift", False)
    minPlusRatio = getConfigFloatOption(configParser, section, "minPlusRatio", 0.25)
    maxPlusRatio = getConfigFloatOption(configParser, section, "maxPlusRatio", 0.75)
    leftPlusRatio = getConfigFloatOption(configParser, section, "leftPlusRatio", 0.3)
    minPeak = getConfigFloatOption(configParser, section, "minPeak", 0.5)
    normalize = getConfigBoolOption(configParser, section, "normalize", True)
    logfilename = getConfigOption(configParser, section, "logfilename", "findall.log")
    withFlag = getConfigOption(configParser, section, "withFlag", "")
    doDirectionality = getConfigBoolOption(configParser, section, "doDirectionality", True)
    trimValue = getConfigFloatOption(configParser, section, "trimValue", 0.1)
    doTrim = getConfigBoolOption(configParser, section, "doTrim", True)
    doAppend = getConfigBoolOption(configParser, section, "doAppend", False)
    rnaSettings = getConfigBoolOption(configParser, section, "rnaSettings", False)
    # NOTE(review): --cache is parsed as an int, but this default comes from
    # the untyped getConfigOption(); confirm a configured value is numeric
    # before it reaches the comparison in openRDSFile().
    cachePages = getConfigOption(configParser, section, "cachePages", None)
    ptype = getConfigOption(configParser, section, "ptype", "")
    controlfile = getConfigOption(configParser, section, "controlfile", None)
    doRevBackground = getConfigBoolOption(configParser, section, "doRevBackground", False)
    useMulti = getConfigBoolOption(configParser, section, "useMulti", True)
    strandfilter = getConfigOption(configParser, section, "strandfilter", "")
    combine5p = getConfigBoolOption(configParser, section, "combine5p", False)

    parser.set_defaults(minHits=minHits, minRatio=minRatio, maxSpacing=maxSpacing, listPeak=listPeak, shift=shift,
                        stringency=stringency, noshift=noshift, autoshift=autoshift, reportshift=reportshift,
                        minPlusRatio=minPlusRatio, maxPlusRatio=maxPlusRatio, leftPlusRatio=leftPlusRatio, minPeak=minPeak,
                        normalize=normalize, logfilename=logfilename, withFlag=withFlag, doDirectionality=doDirectionality,
                        trimValue=trimValue, doTrim=doTrim, doAppend=doAppend, rnaSettings=rnaSettings,
                        cachePages=cachePages, ptype=ptype, controlfile=controlfile, doRevBackground=doRevBackground, useMulti=useMulti,
                        strandfilter=strandfilter, combine5p=combine5p)

    return parser
271
272
def findall(factor, hitfile, outfilename, minHits=4.0, minRatio=4.0, maxSpacing=50, listPeak=False, shiftValue=0,
            stringency=4.0, reportshift=False, minPlusRatio=0.25, maxPlusRatio=0.75, leftPlusRatio=0.3, minPeak=0.5,
            normalize=True, logfilename="findall.log", withFlag="", doDirectionality=True, trimValue=0.1, doTrim=True,
            outputMode="w", rnaSettings=False, cachePages=None, ptype="", controlfile=None, doRevBackground=False,
            useMulti=True, strandfilter="", combine5p=False):
    """Find enriched regions in hitfile and write them to outfilename.

    factor is the label given to every region. A control dataset is used when
    controlfile is not None. outputMode is the mode passed to open() for the
    region file ("w" to overwrite, "a" to append). Progress is printed to
    standard output and the command line and footer are recorded through
    writeLog(). The remaining parameters are handed to RegionFinder and the
    per-chromosome helpers.
    """
    regionFinder = RegionFinder(minRatio, minPeak, minPlusRatio, maxPlusRatio, leftPlusRatio, strandfilter, minHits, trimValue,
                                doTrim, doDirectionality, shiftValue, maxSpacing)

    doControl = controlfile is not None
    pValueType = getPValueType(ptype, doControl, doRevBackground)
    doPvalue = pValueType != "none"

    if rnaSettings:
        regionFinder.useRNASettings()

    doCache = cachePages is not None
    printStatusMessages(regionFinder.shiftValue, normalize, doRevBackground, ptype, doControl, withFlag, useMulti, rnaSettings, regionFinder.strandfilter)
    controlRDS = None
    if doControl:
        print("\ncontrol:")
        controlRDS = openRDSFile(controlfile, cachePages=cachePages, doCache=doCache)

    print("\nsample:")
    hitRDS = openRDSFile(hitfile, cachePages=cachePages, doCache=doCache)

    print("")
    regionFinder.readlen = hitRDS.getReadSize()
    if rnaSettings:
        # With the RNA preset the maximum spacing is the read length.
        regionFinder.maxSpacing = regionFinder.readlen

    writeLog(logfilename, versionString, " ".join(sys.argv[1:]))
    regionFinder.printOptionsSummary(listPeak, useMulti, doCache, pValueType)

    # Dataset sizes in millions of reads, used for RPM normalization.
    regionFinder.sampleRDSsize = len(hitRDS) / 1000000.
    if doControl:
        regionFinder.controlRDSsize = len(controlRDS) / 1000000.

    outfile = open(outfilename, outputMode)
    try:
        description = regionFinder.getAnalysisDescription(hitfile, controlfile, doControl,
                                                          withFlag, listPeak, useMulti, doCache, pValueType)
        outfile.write(description + "\n")

        header = getHeader(normalize, regionFinder.doDirectionality, listPeak, reportshift, doPvalue)
        outfile.write(header + "\n")

        shiftDict = {}
        hitChromList = hitRDS.getChromosomes()
        stringency = max(stringency, 1.0)
        chromosomeList = getChromosomeListToProcess(hitChromList, controlRDS, doControl)
        for chromosome in chromosomeList:
            allregions, outregions = findPeakRegions(regionFinder, hitRDS, controlRDS, chromosome, logfilename, outfilename,
                                                     outfile, stringency, normalize, useMulti, doControl, withFlag, combine5p,
                                                     factor, listPeak)

            if doRevBackground:
                #TODO: this is *almost* the same calculation - there are small yet important differences
                backregions = findBackgroundRegions(regionFinder, hitRDS, controlRDS, chromosome, normalize, useMulti, factor)
                writeChromosomeResults(regionFinder, outregions, outfile, doPvalue, reportshift, shiftDict,
                                       allregions, header, backregions=backregions, pValueType=pValueType)
            else:
                writeNoRevBackgroundResults(regionFinder, outregions, outfile, doPvalue, reportshift, shiftDict,
                                            allregions, header)

        footer = getFooter(regionFinder, shiftDict, reportshift, doRevBackground)
        print(footer)
        outfile.write(footer + "\n")
    finally:
        # Release the region file even when a chromosome fails part-way.
        outfile.close()

    writeLog(logfilename, versionString, outfilename + footer.replace("\n#"," | ")[:-1])
343
344
def getPValueType(ptype, doControl, doRevBackground):
    """Resolve the requested p-value source to "none", "self" or "back".

    ptype is matched case-insensitively, so the lower-case spellings shown in
    the usage text (self|back|none) work as well as the upper-case ones.
    "back" is only honoured when both doControl and doRevBackground are true;
    otherwise the request falls back to "self". When ptype is empty or not
    recognised, the result is "back" if doRevBackground is set, else "self".
    """
    requested = ptype.upper() if ptype else ""
    if requested == "NONE":
        return "none"

    if requested == "SELF":
        return "self"

    if requested == "BACK":
        if doControl and doRevBackground:
            return "back"

        return "self"

    if doRevBackground:
        return "back"

    return "self"
359
360
def printStatusMessages(shiftValue, normalize, doRevBackground, ptype, doControl, withFlag, useMulti, rnaSettings, strandfilter):
    """Print one line to standard output for each noteworthy active setting.

    ptype is checked case-insensitively, matching the lower-case spellings in
    the usage text; an empty or missing ptype produces no p-value message.
    """
    if shiftValue == "learn":
        print("Will try to learn shift")

    if normalize:
        print("Normalizing to RPM")

    if doRevBackground:
        print("Swapping IP and background to calculate FDR")

    requested = ptype.upper() if ptype else ""
    if requested == "BACK":
        if not (doControl and doRevBackground):
            print("must have a control dataset and -revbackground for pValue type 'back'")
    elif requested not in ("", "NONE", "SELF"):
        print("could not use pValue type : %s" % ptype)

    if withFlag != "":
        print("restrict to flag = %s" % withFlag)

    if not useMulti:
        print("using unique reads only")

    if rnaSettings:
        print("using settings appropriate for RNA: -nodirectionality -notrim -noshift")

    if strandfilter == "plus":
        print("only analyzing reads on the plus strand")
    elif strandfilter == "minus":
        print("only analyzing reads on the minus strand")
395
396
def openRDSFile(filename, cachePages=None, doCache=False):
    """Open filename as a ReadDataset and optionally enlarge its DB cache.

    The cache is only changed when cachePages is given and exceeds the
    dataset's default cache size. Returns the opened dataset.
    """
    rds = ReadDataset.ReadDataset(filename, verbose=True, cache=doCache)
    # Guard against None explicitly instead of relying on Python 2 ordering
    # None below every integer (the comparison raises TypeError on Python 3).
    if cachePages is not None and cachePages > rds.getDefaultCacheSize():
        rds.setDBcache(cachePages)

    return rds
403
404
def getHeader(normalize, doDirectionality, listPeak, reportshift, doPvalue):
    """Return the tab-separated column header line for the region file.

    The count column is named "RPM" when normalize is true and "COUNT"
    otherwise; each remaining flag appends its column(s) in a fixed order.
    """
    if normalize:
        countType = "RPM"
    else:
        countType = "COUNT"

    optionalColumns = (
        (doDirectionality, "plus%\tleftPlus%"),
        (listPeak, "peakPos\tpeakHeight"),
        (reportshift, "readShift"),
        (doPvalue, "pValue"),
    )

    headerFields = ["#regionID\tchrom\tstart\tstop", countType, "fold\tmulti%"]
    headerFields.extend([label for enabled, label in optionalColumns if enabled])

    # str.join replaces the deprecated string.join module function.
    return "\t".join(headerFields)
426
427
def getChromosomeListToProcess(hitChromList, controlRDS, doControl):
    """Return the chromosomes of hitChromList that should be analysed.

    "chrM" is always dropped. When doControl is true, chromosomes missing from
    controlRDS.getChromosomes() are dropped too. Input order is preserved.
    """
    if not doControl:
        return [name for name in hitChromList if name != "chrM"]

    controlChromosomes = controlRDS.getChromosomes()
    selected = []
    for name in hitChromList:
        if name == "chrM" or name not in controlChromosomes:
            continue

        selected.append(name)

    return selected
436
437
def _evaluateSampleRegion(regionFinder, hitRDS, controlRDS, chromosome, reads, readStartPositions, totalWeight,
                          uniqueReadCount, numStarts, normalize, useMulti, doControl, factor, listPeak,
                          allregions, outregions):
    """Score one completed run of sample reads and record it if it passes.

    Always appends the integer read count to allregions; appends the region to
    outregions and updates regionFinder.statistics when every threshold is
    met. Returns True only when trimming the region raised IndexError.
    """
    lastReadPos = readStartPositions[-1]
    lastBasePosition = lastReadPos + regionFinder.readlen - 1
    newRegionIndex = regionFinder.statistics["index"] + 1
    if regionFinder.doDirectionality:
        region = Region.DirectionalRegion(readStartPositions[0], lastBasePosition, chrom=chromosome, index=newRegionIndex, label=factor, numReads=totalWeight)
    else:
        region = Region.Region(readStartPositions[0], lastBasePosition, chrom=chromosome, index=newRegionIndex, label=factor, numReads=totalWeight)

    if normalize:
        region.numReads /= regionFinder.sampleRDSsize

    allregions.append(int(region.numReads))
    regionLength = lastReadPos - region.start
    if not regionPassesCriteria(region.numReads, regionFinder.minHits, numStarts, regionFinder.minRatio, regionLength, regionFinder.readlen):
        return False

    region.foldRatio = getFoldRatio(regionFinder, controlRDS, region.numReads, chromosome, region.start, lastReadPos,
                                    useMulti, doControl, normalize)
    if not (region.foldRatio >= regionFinder.minRatio):
        return False

    # first pass, with absolute numbers
    peak = findPeak(reads, region.start, regionLength, regionFinder.readlen, doWeight=True, leftPlus=regionFinder.doDirectionality, shift=regionFinder.shiftValue)
    if regionFinder.doTrim:
        try:
            lastReadPos = trimRegion(region, peak, lastReadPos, regionFinder.trimValue, reads, regionFinder.readlen, regionFinder.doDirectionality,
                                     normalize, regionFinder.sampleRDSsize)
        except IndexError:
            return True

        region.foldRatio = getFoldRatio(regionFinder, controlRDS, region.numReads, chromosome, region.start, lastReadPos,
                                        useMulti, doControl, normalize)

    # just in case it changed, use latest data
    try:
        bestPos = peak.topPos[0]
        peakScore = peak.smoothArray[bestPos]
        if normalize:
            peakScore /= regionFinder.sampleRDSsize
    except Exception:
        # Best effort, as before: a region without a usable peak is dropped.
        # Unlike the former bare except, this lets KeyboardInterrupt and
        # SystemExit propagate.
        return False

    if listPeak:
        region.peakDescription = "%d\t%.1f" % (region.start + bestPos, peakScore)

    if useMulti:
        setMultireadPercentage(region, hitRDS, regionFinder.sampleRDSsize, totalWeight, uniqueReadCount, chromosome, lastReadPos, normalize, regionFinder.doTrim)

    region.shift = peak.shift
    # check that we still pass threshold
    regionLength = lastReadPos - region.start
    plusRatio = float(peak.numPlus)/peak.numHits
    if regionAndPeakPass(region, regionFinder.minHits, regionFinder.minRatio, regionLength, regionFinder.readlen, peakScore, regionFinder.minPeak,
                         regionFinder.minPlusRatio, regionFinder.maxPlusRatio, plusRatio):
        try:
            updateRegion(region, regionFinder.doDirectionality, regionFinder.leftPlusRatio, peak.numLeftPlus, peak.numPlus, plusRatio)
            regionFinder.statistics["index"] += 1
            outregions.append(region)
            regionFinder.statistics["total"] += region.numReads
        except RegionDirectionError:
            regionFinder.statistics["failed"] += 1

    return False


def findPeakRegions(regionFinder, hitRDS, controlRDS, chromosome, logfilename, outfilename,
                    outfile, stringency, normalize, useMulti, doControl, withFlag, combine5p, factor, listPeak):
    """Scan the sample reads of one chromosome and collect candidate regions.

    Reads are accumulated until previousRegionIsDone() reports that the
    current run has ended; the run is then scored and the accumulators are
    reset. When regionFinder.shiftValue is "learn", learnShift() is called
    first on the same reads.

    Returns (allregions, outregions): the integer read count of every run
    that was evaluated, and the region objects that passed all thresholds.
    """
    outregions = []
    allregions = []
    print("chromosome %s" % (chromosome))
    previousHit = - 1 * regionFinder.maxSpacing
    readStartPositions = [-1]
    totalWeight = 0
    uniqueReadCount = 0
    reads = []
    numStarts = 0
    hitDict = hitRDS.getReadsDict(fullChrom=True, chrom=chromosome, flag=withFlag, withWeight=True, doMulti=useMulti, findallOptimize=True,
                                  strand=regionFinder.stranded, combine5p=combine5p)

    maxCoord = hitRDS.getMaxCoordinate(chromosome, doMulti=useMulti)
    if regionFinder.shiftValue == "learn":
        learnShift(regionFinder, hitDict, controlRDS, chromosome, maxCoord, logfilename, outfilename,
                                outfile, numStarts, stringency, normalize, useMulti, doControl)

    for read in hitDict[chromosome]:
        pos = read["start"]
        if previousRegionIsDone(pos, previousHit, regionFinder.maxSpacing, maxCoord):
            trimFailed = _evaluateSampleRegion(regionFinder, hitRDS, controlRDS, chromosome, reads, readStartPositions, totalWeight,
                                               uniqueReadCount, numStarts, normalize, useMulti, doControl, factor, listPeak,
                                               allregions, outregions)
            if trimFailed:
                regionFinder.statistics["badRegionTrim"] += 1

            # Always start a fresh run here. The earlier code used "continue"
            # on a failed trim or missing peak, which skipped this reset and
            # dropped the current read, so the same stale run was re-evaluated
            # for each following read and badRegionTrim was never counted.
            readStartPositions = []
            totalWeight = 0
            uniqueReadCount = 0
            reads = []
            numStarts = 0

        if pos not in readStartPositions:
            numStarts += 1

        readStartPositions.append(pos)
        weight = read["weight"]
        totalWeight += weight
        if weight == 1.0:
            uniqueReadCount += 1

        reads.append({"start": pos, "sense": read["sense"], "weight": weight})
        previousHit = pos

    return allregions, outregions
544
545
def _evaluateBackgroundRegion(regionFinder, hitRDS, chromosome, currentReadList, currentHitList, currentTotalWeight,
                              numStarts, normalize, useMulti, factor, backregions):
    """Score one completed run of control reads against the sample.

    Always appends the integer read count to backregions and calls
    updateControlStatistics() when the run passes the thresholds. Returns
    True only when trimming the region raised IndexError.
    """
    lastReadPos = currentHitList[-1]
    lastBasePosition = lastReadPos + regionFinder.readlen - 1
    region = Region.Region(currentHitList[0], lastBasePosition, chrom=chromosome, label=factor, numReads=currentTotalWeight)
    if normalize:
        region.numReads /= regionFinder.controlRDSsize

    backregions.append(int(region.numReads))
    # NOTE(review): the region is rebuilt here, so the normalization applied
    # above does not carry into the checks below - confirm this is intended.
    region = Region.Region(currentHitList[0], lastBasePosition, chrom=chromosome, label=factor, numReads=currentTotalWeight)
    regionLength = lastReadPos - region.start
    if not regionPassesCriteria(region.numReads, regionFinder.minHits, numStarts, regionFinder.minRatio, regionLength, regionFinder.readlen):
        return False

    numMock = 1. + hitRDS.getCounts(chromosome, region.start, lastReadPos, uniqs=True, multi=useMulti, splices=False, reportCombined=True)
    if normalize:
        numMock /= regionFinder.sampleRDSsize

    foldRatio = region.numReads / numMock
    if not (foldRatio >= regionFinder.minRatio):
        return False

    # first pass, with absolute numbers
    peak = findPeak(currentReadList, region.start, lastReadPos - region.start, regionFinder.readlen, doWeight=True,
                    leftPlus=regionFinder.doDirectionality, shift=regionFinder.shiftValue)

    if regionFinder.doTrim:
        try:
            # NOTE(review): the trim value is a fixed 20. here rather than
            # regionFinder.trimValue - confirm against trimRegion().
            lastReadPos = trimRegion(region, peak, lastReadPos, 20., currentReadList, regionFinder.readlen, regionFinder.doDirectionality,
                                     normalize, regionFinder.controlRDSsize)
        except IndexError:
            return True

        numMock = 1. + hitRDS.getCounts(chromosome, region.start, lastReadPos, uniqs=True, multi=useMulti, splices=False, reportCombined=True)
        if normalize:
            numMock /= regionFinder.sampleRDSsize

        foldRatio = region.numReads / numMock

    # just in case it changed, use latest data
    try:
        bestPos = peak.topPos[0]
        peakScore = peak.smoothArray[bestPos]
    except IndexError:
        return False

    # normalize to RPM
    if normalize:
        peakScore /= regionFinder.controlRDSsize

    # check that we still pass threshold; foldRatio is passed in the position
    # that the first check above uses for numStarts.
    regionLength = lastReadPos - region.start
    if regionPassesCriteria(region.numReads, regionFinder.minHits, foldRatio, regionFinder.minRatio, regionLength, regionFinder.readlen):
        updateControlStatistics(regionFinder, peak, region.numReads, peakScore)

    return False


def findBackgroundRegions(regionFinder, hitRDS, controlRDS, chromosome, normalize, useMulti, factor):
    """Scan the control reads of one chromosome with sample and control swapped.

    Runs of control reads are delimited with previousRegionIsDone() and scored
    against read counts taken from hitRDS. Returns the list of integer read
    counts, one per run that was evaluated.
    """
    print("calculating background...")
    previousHit = - 1 * regionFinder.maxSpacing
    currentHitList = [-1]
    currentTotalWeight = 0
    currentReadList = []
    backregions = []
    numStarts = 0
    hitDict = controlRDS.getReadsDict(fullChrom=True, chrom=chromosome, withWeight=True, doMulti=useMulti, findallOptimize=True)
    maxCoord = controlRDS.getMaxCoordinate(chromosome, doMulti=useMulti)
    for read in hitDict[chromosome]:
        pos = read["start"]
        if previousRegionIsDone(pos, previousHit, regionFinder.maxSpacing, maxCoord):
            trimFailed = _evaluateBackgroundRegion(regionFinder, hitRDS, chromosome, currentReadList, currentHitList, currentTotalWeight,
                                                   numStarts, normalize, useMulti, factor, backregions)
            if trimFailed:
                regionFinder.statistics["badRegionTrim"] += 1

            # Always start a fresh run here. The earlier code used "continue"
            # on a failed trim or missing peak, which skipped this reset and
            # dropped the current read, so the same stale run was re-evaluated
            # for each following read and badRegionTrim was never counted.
            currentHitList = []
            currentTotalWeight = 0
            currentReadList = []
            numStarts = 0

        if pos not in currentHitList:
            numStarts += 1

        currentHitList.append(pos)
        weight = read["weight"]
        currentTotalWeight += weight
        currentReadList.append({"start": pos, "sense": read["sense"], "weight": weight})
        previousHit = pos

    return backregions
629
630
631 def learnShift(regionFinder, hitDict, controlRDS, chromosome, maxCoord, logfilename, outfilename,
632                outfile, numStarts, stringency, normalize, useMulti, doControl):
633
634     print "learning shift.... will need at least 30 training sites"
635     previousHit = -1 * regionFinder.maxSpacing
636     positionList = [-1]
637     totalWeight = 0
638     readList = []
639     shiftDict = {}
640     count = 0
641     for read in hitDict[chromosome]:
642         pos = read["start"]
643         if previousRegionIsDone(pos, previousHit, regionFinder.maxSpacing, maxCoord):
644             if normalize:
645                 totalWeight /= regionFinder.sampleRDSsize
646
647             regionStart = positionList[0]
648             regionStop = positionList[-1]
649             regionLength = regionStop - regionStart
650             if regionPassesCriteria(totalWeight, regionFinder.minHits, numStarts, regionFinder.minRatio, regionLength, regionFinder.readlen, stringency=stringency):
651                 foldRatio = getFoldRatio(regionFinder, controlRDS, totalWeight, chromosome, regionStart, regionStop, useMulti, doControl, normalize)
652                 if foldRatio >= regionFinder.minRatio:
653                     updateShiftDict(shiftDict, readList, regionStart, regionLength, regionFinder.readlen)
654                     count += 1
655
656             positionList = []
657             totalWeight = 0
658             readList = []
659
660         if pos not in positionList:
661             numStarts += 1
662
663         positionList.append(pos)
664         weight = read["weight"]
665         totalWeight += weight
666         readList.append({"start": pos, "sense": read["sense"], "weight": weight})
667         previousHit = pos
668
669     outline = "#learn: stringency=%.2f min_signal=%2.f min_ratio=%.2f min_region_size=%d\n#number of training examples: %d" % (stringency,
670                                                                                                                                stringency * regionFinder.minHits,
671                                                                                                                                stringency * regionFinder.minRatio,
672                                                                                                                                stringency * regionFinder.readlen,
673                                                                                                                                count)
674
675     print outline
676     writeLog(logfilename, versionString, outfilename + outline)
677     regionFinder.shiftValue = getShiftValue(shiftDict, count, logfilename, outfilename)
678     outline = "#picked shiftValue to be %d" % regionFinder.shiftValue
679     print outline
680     print >> outfile, outline
681     writeLog(logfilename, versionString, outfilename + outline)
682
683
def previousRegionIsDone(pos, previousHit, maxSpacing, maxCoord):
    """ Tell whether the read at pos closes the region that ended at previousHit.

        True when the gap to the previous hit exceeds maxSpacing, or when pos is
        exactly maxCoord.
    """
    gap = abs(pos - previousHit)
    return gap > maxSpacing or pos == maxCoord
686
687
def regionPassesCriteria(sumAll, minHits, numStarts, minRatio, regionLength, readlen, stringency=1):
    """ Check a region against the three thresholds, each scaled by stringency.

        sumAll must be at least stringency * minHits, numStarts must exceed
        stringency * minRatio and regionLength must exceed stringency * readlen.
    """
    hitsFloor = stringency * minHits
    ratioFloor = stringency * minRatio
    lengthFloor = stringency * readlen

    return sumAll >= hitsFloor and numStarts > ratioFloor and regionLength > lengthFloor
690
691
def trimRegion(region, peak, regionStop, trimValue, currentReadList, readlen, doDirectionality, normalize, hitRDSsize):
    """ Shrink region to the span whose smoothed signal reaches trimValue times
        the peak score, and refresh peak and region from the trimmed span.

        Updates region.start, region.stop and region.numReads as well as
        peak.numPlus, peak.numLeftPlus, peak.topPos and peak.smoothArray in place.
        Returns the trimmed stop position (before readlen is added for region.stop).
    """
    summitScore = peak.smoothArray[peak.topPos[0]]
    if normalize:
        summitScore /= hitRDSsize

    signalFloor = trimValue * summitScore
    leftOffset = findStartEdgePosition(peak, signalFloor)
    lastOffset = regionStop - region.start - 1
    rightOffset = findStopEdgePosition(peak, lastOffset, signalFloor)

    # both offsets are relative to the untrimmed start, so derive the new stop
    # before region.start moves
    trimmedStop = region.start + rightOffset
    region.start += leftOffset

    trimmedPeak = findPeak(currentReadList, region.start, trimmedStop - region.start, readlen,
                           doWeight=True, leftPlus=doDirectionality, shift=peak.shift)
    for attribute in ("numPlus", "numLeftPlus", "topPos", "smoothArray"):
        setattr(peak, attribute, getattr(trimmedPeak, attribute))

    region.numReads = trimmedPeak.numHits
    if normalize:
        region.numReads /= hitRDSsize

    region.stop = trimmedStop + readlen - 1

    return trimmedStop
718
719
def findStartEdgePosition(peak, minSignalThresh):
    """ Scan upward from offset 0 and return the first offset that
        peakEdgeLocated accepts.
    """
    offset = 0
    while True:
        if peakEdgeLocated(peak, offset, minSignalThresh):
            return offset

        offset += 1
726
727
def findStopEdgePosition(peak, stop, minSignalThresh):
    """ Scan downward from offset stop and return the first offset that
        peakEdgeLocated accepts.
    """
    offset = stop
    while True:
        if peakEdgeLocated(peak, offset, minSignalThresh):
            return offset

        offset -= 1
733
734
def peakEdgeLocated(peak, position, minSignalThresh):
    """ Tell whether position is an acceptable edge: its smoothed signal reaches
        minSignalThresh, or it is the top position of the peak.
    """
    signalReached = peak.smoothArray[position] >= minSignalThresh
    return signalReached or position == peak.topPos[0]
737
738
def getFoldRatio(regionFinder, controlRDS, sumAll, chromosome, regionStart, regionStop, useMulti, doControl, normalize):
    """ Return the enrichment of sumAll over the control counts in the region.

        Without a control (doControl false) the result is regionFinder.minRatio.
        Otherwise the control count gets a pseudocount of 1 and, when normalize
        is set, is divided by regionFinder.controlRDSsize before the division.
    """
    if not doControl:
        return regionFinder.minRatio

    controlCount = controlRDS.getCounts(chromosome, regionStart, regionStop, uniqs=True,
                                        multi=useMulti, splices=False, reportCombined=True)
    numMock = 1. + controlCount
    if normalize:
        numMock /= regionFinder.controlRDSsize

    return sumAll / numMock
750
751
def updateShiftDict(shiftDict, readList, regionStart, regionLength, readlen):
    """ Find the auto-shifted peak of a region and add one to the tally kept in
        shiftDict for that peak's shift.
    """
    bestShift = findPeak(readList, regionStart, regionLength, readlen, doWeight=True, shift="auto").shift
    shiftDict[bestShift] = shiftDict.get(bestShift, 0) + 1
758
759
760 def getShiftValue(shiftDict, count, logfilename, outfilename):
761     if count < 30:
762         outline = "#too few training examples to pick a shiftValue - defaulting to 0\n#consider picking a lower minimum or threshold"
763         print outline
764         writeLog(logfilename, versionString, outfilename + outline)
765         shiftValue = 0
766     else:
767         shiftValue = getBestShiftInDict(shiftDict)
768         print shiftDict
769
770     return shiftValue
771
772
def updateControlStatistics(regionFinder, peak, sumAll, peakScore):
    """ Tally a background region in regionFinder.statistics.

        Nothing is counted unless peakScore reaches regionFinder.minPeak and the
        plus-strand fraction of the peak lies within
        [regionFinder.minPlusRatio, regionFinder.maxPlusRatio].  With
        directionality on, the region counts towards "mIndex"/"mTotal" when the
        left-plus fraction exceeds regionFinder.leftPlusRatio and towards
        "failed" otherwise; with directionality off it always counts towards
        "mIndex"/"mTotal".
    """

    plusRatio = float(peak.numPlus)/peak.numHits
    if peakScore >= regionFinder.minPeak and regionFinder.minPlusRatio <= plusRatio <= regionFinder.maxPlusRatio:
        if regionFinder.doDirectionality:
            # the left-of-peak plus count lives in numLeftPlus (the attribute trimRegion
            # maintains); float() keeps the ratio fractional for int counts
            if regionFinder.leftPlusRatio < float(peak.numLeftPlus) / peak.numPlus:
                regionFinder.statistics["mIndex"] += 1
                regionFinder.statistics["mTotal"] += sumAll
            else:
                regionFinder.statistics["failed"] += 1
        else:
            # we have a region, but didn't check for directionality
            regionFinder.statistics["mIndex"] += 1
            regionFinder.statistics["mTotal"] += sumAll
787
788
def getRegion(regionStart, regionStop, factor, index, chromosome, sumAll, foldRatio, multiP,
              peakDescription, shift, doDirectionality, leftPlusRatio, numLeft,
              numPlus, plusRatio):
    """ Build the region object for a region that passed the earlier criteria.

        With doDirectionality off a Region.Region is returned.  With it on, the
        fraction numLeft / numPlus must exceed leftPlusRatio: a
        Region.DirectionalRegion carrying the plus and left-plus percentages is
        returned if it does, and RegionDirectionError is raised if it does not.
    """

    if not doDirectionality:
        # we have a region, but didn't check for directionality
        return Region.Region(regionStart, regionStop, factor, index, chromosome,
                             sumAll, foldRatio, multiP, peakDescription, shift)

    # float() keeps the ratio fractional when both counts are ints, so the test
    # agrees with the leftP percentage reported below
    if not leftPlusRatio < float(numLeft) / numPlus:
        raise RegionDirectionError

    # we have a region that passes all criteria
    plusP = plusRatio * 100.
    leftP = 100. * numLeft / numPlus

    return Region.DirectionalRegion(regionStart, regionStop,
                                    factor, index, chromosome, sumAll,
                                    foldRatio, multiP, plusP, leftP,
                                    peakDescription, shift)
811
812
def setMultireadPercentage(region, hitRDS, hitRDSsize, currentTotalWeight, currentUniqueCount, chromosome, lastReadPos, normalize, doTrim):
    """ Store on region.multiP the percentage of region.numReads that comes from
        multireads.

        When doTrim is set the multiread count is asked from hitRDS for the span
        region.start..lastReadPos; otherwise it is currentTotalWeight minus
        currentUniqueCount.  region.multiP is left untouched when the division
        by region.numReads raises ZeroDivisionError.
    """
    if doTrim:
        multireadSum = hitRDS.getMultiCount(chromosome, region.start, lastReadPos)
    else:
        multireadSum = currentTotalWeight - currentUniqueCount

    # normalize to RPM
    if normalize:
        multireadSum /= hitRDSsize

    try:
        percentage = 100. * (multireadSum / region.numReads)
    except ZeroDivisionError:
        pass
    else:
        region.multiP = percentage
829
830
def regionAndPeakPass(region, minHits, minRatio, regionLength, readlen, peakScore, minPeak, minPlusRatio, maxPlusRatio, plusRatio):
    """ Tell whether both the region and its peak are acceptable.

        The region must pass regionPassesCriteria (called with region.numReads
        and region.foldRatio); the peak must score at least minPeak with
        plusRatio inside [minPlusRatio, maxPlusRatio].
    """
    if not regionPassesCriteria(region.numReads, minHits, region.foldRatio, minRatio, regionLength, readlen):
        return False

    return bool(peakScore >= minPeak and minPlusRatio <= plusRatio <= maxPlusRatio)
838
839
def updateRegion(region,
                 doDirectionality, leftPlusRatio, numLeft,
                 numPlus, plusRatio):
    """ Record the strand percentages on region when directionality is checked.

        Does nothing when doDirectionality is off.  Otherwise the fraction
        numLeft / numPlus must exceed leftPlusRatio: region.plusP and
        region.leftP are set if it does, and RegionDirectionError is raised if
        it does not.
    """

    if not doDirectionality:
        return

    # float() keeps the ratio fractional when both counts are ints, so the test
    # agrees with the leftP percentage stored below
    if not leftPlusRatio < float(numLeft) / numPlus:
        raise RegionDirectionError

    region.plusP = plusRatio * 100.
    region.leftP = 100. * numLeft / numPlus
850
851
def writeNoRevBackgroundResults(regionFinder, outregions, outfile, doPvalue, reportshift, shiftDict,
                                allregions, header):
    """ Write one chromosome's results when there is no reverse background:
        delegates to writeChromosomeResults with no background regions and the
        "self" p-value type.
    """
    noBackgroundRegions = []
    writeChromosomeResults(regionFinder, outregions, outfile, doPvalue, reportshift, shiftDict,
                           allregions, header, backregions=noBackgroundRegions, pValueType="self")
857
858
859 def writeChromosomeResults(regionFinder, outregions, outfile, doPvalue, reportshift, shiftDict,
860                            allregions, header, backregions=[], pValueType="none"):
861
862     print regionFinder.statistics["mIndex"], regionFinder.statistics["mTotal"]
863     if doPvalue:
864         if pValueType == "self":
865             poissonmean = calculatePoissonMean(allregions)
866         else:
867             poissonmean = calculatePoissonMean(backregions)
868
869     print header
870     writeRegions(outregions, outfile, doPvalue, poissonmean, shiftValue=regionFinder.shiftValue, reportshift=reportshift, shiftDict=shiftDict)
871
872
873 def calculatePoissonMean(dataList):
874     dataList.sort()
875     listSize = float(len(dataList))
876     try:
877         poissonmean = sum(dataList) / listSize
878     except ZeroDivisionError:
879         poissonmean = 0
880
881     print "Poisson n=%d, p=%f" % (listSize, poissonmean)
882
883     return poissonmean
884
885
886 def writeRegions(outregions, outfile, doPvalue, poissonmean, shiftValue=0, reportshift=False, shiftDict={}):
887     for region in outregions:
888         if shiftValue == "auto" and reportshift:
889             try:
890                 shiftDict[region.shift] += 1
891             except KeyError:
892                 shiftDict[region.shift] = 1
893
894         outline = getRegionString(region, reportshift)
895
896         # iterative poisson from http://stackoverflow.com/questions/280797?sort=newest
897         if doPvalue:
898             sumAll = int(region.numReads)
899             pValue = calculatePValue(sumAll, poissonmean)
900             outline += "\t%1.2g" % pValue
901
902         print outline
903         print >> outfile, outline
904
905
def calculatePValue(sum, poissonmean):
    """ Return exp(-poissonmean) * poissonmean**k / k!, built up iteratively,
        where k is sum truncated to an integer.

        A count of zero or less returns exp(-poissonmean).
    """
    # the product needs a whole number of steps; a float count is truncated here
    # instead of being handed to the loop as-is
    steps = int(sum)
    pValue = math.exp(-poissonmean)
    step = 1
    while step <= steps:
        pValue *= poissonmean
        pValue /= step
        step += 1

    return pValue
914
915
def getRegionString(region, reportShift):
    """ Return the output line for region, using the with-shift variant when
        reportShift is set.
    """
    return region.printRegionWithShift() if reportShift else region.printRegion()
923
924
def getFooter(regionFinder, shiftDict, reportshift, doRevBackground):
    """ Build the multi-line footer summarizing regionFinder.statistics.

        Always reports the total and the number of regions.  Adds the number of
        directionality failures when regionFinder.doDirectionality is set, the
        background summary with an FDR percentage (capped at 100, and 0 when
        there are no regions) when doRevBackground is set, the most frequent
        shift when shifts were picked automatically and reported, and the number
        of regions discarded by trimming when there were any.

        Returns the lines joined with newlines.
    """
    statistics = regionFinder.statistics
    index = statistics["index"]
    mIndex = statistics["mIndex"]
    footerLines = ["#stats:\t%.1f RPM in %d regions" % (statistics["total"], index)]
    if regionFinder.doDirectionality:
        footerLines.append("#\t\t%d additional regions failed directionality filter" % statistics["failed"])

    if doRevBackground:
        try:
            percent = min(100. * (float(mIndex)/index), 100.)
        except ZeroDivisionError:
            percent = 0.

        footerLines.append("#%d regions (%.1f RPM) found in background (FDR = %.2f percent)" % (mIndex, statistics["mTotal"], percent))

    # an empty tally has no mode, and max() over it would raise ValueError,
    # so the line is left out in that case
    if regionFinder.shiftValue == "auto" and reportshift and shiftDict:
        bestShift = getBestShiftInDict(shiftDict)
        footerLines.append("#mode of shift values: %d" % bestShift)

    if statistics["badRegionTrim"] > 0:
        footerLines.append("#%d regions discarded due to trimming problems" % statistics["badRegionTrim"])

    return "\n".join(footerLines)
948
949
def getBestShiftInDict(shiftDict):
    """ Return the key of shiftDict that has the largest value. """
    return max(shiftDict, key=shiftDict.get)
952
953
# script entry point: when run directly, hand the full argument list to main()
if __name__ == "__main__":
    main(sys.argv)