snapshot of 4.0a development. initial git repo commit
[erange.git] / commoncode.py
1 #
2 #  commoncode.py
3 #  ENRAGE
4 #
5
import os
import shutil
import sqlite3 as sqlite
import string
import sys
import tempfile
from array import array
from collections import defaultdict
from os import environ
from time import strftime
15
16 commoncodeVersion = 5.5
17 currentRDSversion = 1.1
18
19 if environ.get("CISTEMATIC_TEMP"):
20     cisTemp = environ.get("CISTEMATIC_TEMP")
21 else:
22     cisTemp = "/tmp"
23
24 tempfile.tempdir = cisTemp
25
26
27 def getReverseComplement(base):
28     revComp = {"A": "T",
29                "T": "A",
30                "G": "C",
31                "C": "G",
32                "N": "N"
33         }
34
35     return revComp[base]
36
37
38 def countDuplicatesInList(listToCheck):
39     tally = defaultdict(int)
40     for item in listToCheck:
41         tally[item] += 1
42
43     return tally.items()
44
45
46 def writeLog(logFile, messenger, message):
47     """ create a log file to write a message from a messenger or append to an existing file.
48     """
49     try:
50         logfile = open(logFile)
51     except IOError:
52         logfile = open(logFile, "w")
53     else:
54         logfile = open(logFile, "a")
55
56     logfile.writelines("%s: [%s] %s\n" % (strftime("%Y-%m-%d %H:%M:%S"), messenger, message))
57     logfile.close()
58
59
60 def getMergedRegions(regionfilename, maxDist=1000, minHits=0, verbose=False, keepLabel=False,
61                      fullChrom = False, chromField=1, scoreField=4, pad=0, compact=False,
62                      doMerge=True, keepPeak=False, returnTop=0):
63
64     """ returns a list of merged overlapping regions; 
65     can optionally filter regions that have a scoreField fewer than minHits.
66     Can also optionally return the label of each region, as well as the
67     peak, if supplied (peakPos and peakHeight should be the last 2 fields).
68     Can return the top regions based on score if higher than minHits.
69     """
70     infile = open(regionfilename)
71     lines = infile.readlines()
72     regions = getMergedRegionsFromList(lines, maxDist, minHits, verbose, keepLabel,
73                                        fullChrom, chromField, scoreField, pad, compact,
74                                        doMerge, keepPeak, returnTop)
75
76     infile.close()
77
78     return regions
79
80
81 def getMergedRegionsFromList(regionList, maxDist=1000, minHits=0, verbose=False, keepLabel=False,
82                      fullChrom = False, chromField=1, scoreField=4, pad=0, compact=False,
83                      doMerge=True, keepPeak=False, returnTop=0):
84     """ returns a list of merged overlapping regions; 
85     can optionally filter regions that have a scoreField fewer than minHits.
86     Can also optionally return the label of each region, as well as the
87     peak, if supplied (peakPos and peakHeight should be the last 2 fields).
88     Can return the top regions based on score if higher than minHits.
89     """
90     regions = {}
91     hasPvalue = 0
92     hasShift = 0
93     if 0 < returnTop < len(regionList):
94         scores = []
95         for regionEntry in regionList:
96             if regionEntry[0] == "#":
97                 if "pvalue" in regionEntry:
98                     hasPvalue = 1
99
100                 if "readShift" in regionEntry:
101                     hasShift = 1
102
103                 continue
104
105             fields = regionEntry.strip().split("\t")
106             hits = float(fields[scoreField].strip())
107             scores.append(hits)
108
109         scores.sort()
110         returnTop = -1 * returnTop 
111         minScore = scores[returnTop]
112         if minScore > minHits:
113             minHits = minScore
114
115     mergeCount = 0
116     chromField = int(chromField)
117     count = 0
118     #TODO: Current algorithm processes input file line by line and compares with prior lines.  Problem is it
119     #      exits at the first merge.  This is not a problem when the input is sorted by start position, but in
120     #      the case of 3 regions ABC that are in the input file as ACB as it goes now when processing C there
121     #      will be no merge with A as B is needed to bridge the two.  When it comes time to process B it will
122     #      be merged with A but that will exit the loop and the merge with C will be missed.
123     for regionEntry in regionList:
124         if regionEntry[0] == "#":
125             if "pvalue" in regionEntry:
126                 hasPvalue = 1
127
128             if "readShift" in regionEntry:
129                 hasShift = 1
130
131             continue
132
133         fields = regionEntry.strip().split("\t")
134         if minHits >= 0:
135             try:
136                 hits = float(fields[scoreField].strip())
137             except (IndexError, ValueError):
138                 continue
139
140             if hits < minHits:
141                 continue
142
143         if compact:
144             (chrom, pos) = fields[chromField].split(":")
145             (front, back) = pos.split("-")
146             start = int(front)
147             stop = int(back)
148         elif chromField > 1:
149             label = string.join(fields[:chromField],"\t")
150             chrom = fields[chromField]
151             start = int(fields[chromField + 1]) - pad
152             stop = int(fields[chromField + 2]) + pad
153         else:
154             label = fields[0]
155             chrom = fields[1]
156             start = int(fields[2]) - pad
157             stop = int(fields[3]) + pad
158
159         if not fullChrom:
160             chrom = chrom[3:]
161
162         length = abs(stop - start)
163         if keepPeak:
164             peakPos = int(fields[-2 - hasPvalue - hasShift])
165             peakHeight = float(fields[-1 - hasPvalue - hasShift])
166
167         if chrom not in regions:
168             regions[chrom] = []
169
170         merged = False
171
172         if doMerge and len(regions[chrom]) > 0:
173             for index in range(len(regions[chrom])):
174                 if keepLabel and keepPeak:
175                     (rlabel, rstart, rstop, rlen, rpeakPos, rpeakHeight) = regions[chrom][index]
176                 elif keepLabel:
177                     (rlabel, rstart, rstop, rlen) = regions[chrom][index]
178                 elif keepPeak:
179                     (rstart, rstop, rlen, rpeakPos, rpeakHeight) = regions[chrom][index]
180                 else:
181                     (rstart, rstop, rlen) = regions[chrom][index]
182
183                 if regionsOverlap(start, stop, rstart, rstop) or regionsAreWithinDistance(start, stop, rstart, rstop, maxDist):
184                     if start < rstart:
185                         rstart = start
186
187                     if rstop < stop:
188                         rstop = stop
189
190                     rlen = abs(rstop - rstart)
191                     if keepPeak:
192                         if peakHeight > rpeakHeight:
193                             rpeakHeight = peakHeight
194                             rpeakPos = peakPos
195
196                     if keepLabel and keepPeak:
197                         regions[chrom][index] = (label, rstart, rstop, rlen, rpeakPos, rpeakHeight)
198                     elif keepLabel:
199                         regions[chrom][index] = (label, rstart, rstop, rlen)
200                     elif keepPeak:
201                         regions[chrom][index] = (rstart, rstop, rlen, rpeakPos, rpeakHeight)
202                     else:
203                         regions[chrom][index] = (rstart, rstop, rlen)
204
205                     mergeCount += 1
206                     merged = True
207                     break
208
209         if not merged:
210             if keepLabel and keepPeak:
211                 regions[chrom].append((label, start, stop, length, peakPos, peakHeight))
212             elif keepLabel:
213                 regions[chrom].append((label, start, stop, length))
214             elif keepPeak:
215                 regions[chrom].append((start, stop, length, peakPos, peakHeight))
216             else:
217                 regions[chrom].append((start, stop, length))
218
219             count += 1
220
221         if verbose and (count % 100000 == 0):
222             print count
223
224     regionCount = 0
225     for chrom in regions:
226         regionCount += len(regions[chrom])
227         if keepLabel:
228             regions[chrom].sort(cmp=lambda x,y:cmp(x[1], y[1]))
229         else:
230             regions[chrom].sort()
231
232     if verbose:
233         print "merged %d times" % mergeCount
234         print "returning %d regions" % regionCount
235
236     return regions
237
238
239 def regionsOverlap(start, stop, rstart, rstop):
240     if start > stop:
241         (start, stop) = (stop, start)
242
243     if rstart > rstop:
244         (rstart, rstop) = (rstop, rstart)
245
246     return (rstart <= start <= rstop) or (rstart <= stop <= rstop) or (start <= rstart <= stop) or (start <= rstop <= stop)
247
248
249 def regionsAreWithinDistance(start, stop, rstart, rstop, maxDist):
250     if start > stop:
251         (start, stop) = (stop, start)
252
253     if rstart > rstop:
254         (rstart, rstop) = (rstop, rstart)
255
256     return (abs(rstart-stop) <= maxDist) or (abs(rstop-start) <= maxDist)
257
258
259 def findPeak(hitList, start, length, readlen=25, doWeight=False, leftPlus=False,
260              shift=0, returnShift=False, maxshift=75):
261     """ find the peak in a list of reads (hitlist) in a region
262     of a given length and absolute start point. returns a
263     list of peaks, the number of hits, a triangular-smoothed
264     version of hitlist, and the number of reads that are
265     forward (plus) sense.
266     If doWeight is True, weight the reads accordingly.
267     If leftPlus is True, return the number of plus reads left of
268     the peak, taken to be the first TopPos position.
269     """
270
271     seqArray = array("f", [0.] * length)
272     smoothArray = array("f", [0.] * length)
273     numHits = 0.
274     numPlus = 0.
275     regionArray = []
276     if shift == "auto":
277         shift = getBestShiftForRegion(hitList, start, length, doWeight, maxshift)
278
279     # once we have the best shift, compute seqArray
280     for read in hitList:
281         currentpos = read[0] - start
282         if read[1] == "+":
283             currentpos += shift
284         else:
285             currentpos -= shift
286
287         if (currentpos <  1 - readlen) or (currentpos >= length):
288             continue
289
290         hitIndex = 0
291         if doWeight:
292             weight = read[2]
293         else:
294             weight = 1.0
295
296         numHits += weight
297         if leftPlus:
298             regionArray.append(read)
299
300         while currentpos < 0:
301             hitIndex += 1
302             currentpos += 1
303
304         while hitIndex < readlen and  currentpos < length:
305             seqArray[currentpos] += weight
306             hitIndex += 1
307             currentpos += 1
308
309         if read[1] == "+":
310             numPlus += weight
311
312     # implementing a triangular smooth
313     for pos in range(2,length -2):
314         smoothArray[pos] = (seqArray[pos -2] + 2 * seqArray[pos - 1] + 3 * seqArray[pos] + 2 * seqArray[pos + 1] + seqArray[pos + 2]) / 9.0
315
316     topNucleotide = 0
317     topPos = []
318     for currentpos in xrange(length):
319         if topNucleotide < smoothArray[currentpos]:
320             topNucleotide = smoothArray[currentpos]
321             topPos = [currentpos]
322         elif topNucleotide  == smoothArray[currentpos]:
323             topPos.append(currentpos)
324
325     if leftPlus:
326         numLeftPlus = 0
327         maxPos = topPos[0]
328         for read in regionArray:
329             if doWeight:
330                 weight = read[2]
331             else:
332                 weight = 1.0
333
334             currentPos = read[0] - start
335             if currentPos <= maxPos and read[1] == "+":
336                 numLeftPlus += weight
337
338         if returnShift:
339             return (topPos, numHits, smoothArray, numPlus, numLeftPlus, shift)
340         else:
341             return (topPos, numHits, smoothArray, numPlus, numLeftPlus)
342     else:
343         if returnShift:
344             return (topPos, numHits, smoothArray, numPlus, shift)
345         else:
346             return (topPos, numHits, smoothArray, numPlus)
347
348
349 def getBestShiftForRegion(hitList, start, length, doWeight=False, maxShift=75):
350     bestShift = 0
351     lowestScore = 20000000000
352     for testShift in xrange(maxShift + 1):
353         shiftArray = array("f", [0.] * length)
354         for read in hitList:
355             currentpos = read[0] - start
356             if read[1] == "+":
357                 currentpos += testShift
358             else:
359                 currentpos -= testShift
360
361             if (currentpos < 1) or (currentpos >= length):
362                 continue
363
364             if doWeight:
365                 weight = read[2]
366             else:
367                 weight = 1.0
368
369             if read[1] == "+":
370                 shiftArray[currentpos] += weight
371             else:
372                 shiftArray[currentpos] -= weight
373
374         currentScore = 0
375         for score in shiftArray:
376             currentScore += abs(score)
377
378         print currentScore
379         if currentScore < lowestScore:
380             bestShift = testShift
381             lowestScore = currentScore
382
383     return bestShift
384
385
386 def getFeaturesByChromDict(genomeObject, additionalRegionsDict={}, ignorePseudo=False,
387                            restrictList=[], regionComplement=False, maxStop=250000000):
388     """ return a dictionary of cistematic gene features. Requires
389     cistematic, obviously. Can filter-out pseudogenes. Will use
390     additional regions dict to supplement gene models, if available.
391     Can restrict output to a list of GIDs.
392     If regionComplement is set to true, returns the regions *outside* of the
393     calculated boundaries, which is useful for retrieving intronic and
394     intergenic regions. maxStop is simply used to define the uppermost
395     boundary of the complement region.
396     """ 
397     featuresDict = genomeObject.getallGeneFeatures()
398     restrictGID = False
399     if len(restrictList) > 0:
400         restrictGID = True
401
402     if len(additionalRegionsDict) > 0:
403         sortList = []
404         for chrom in additionalRegionsDict:
405             for (label, start, stop, length) in additionalRegionsDict[chrom]:
406                 if label not in sortList:
407                     sortList.append(label)
408
409                 if label not in featuresDict:
410                     featuresDict[label] = []
411                     sense = "+"
412                 else:
413                     sense = featuresDict[label][0][-1]
414
415                 featuresDict[label].append(("custom", chrom, start, stop, sense))
416
417         for gid in sortList:
418             featuresDict[gid].sort(cmp=lambda x,y:cmp(x[2], y[2]))
419
420     featuresByChromDict = {}
421     for gid in featuresDict:
422         if restrictGID and gid not in restrictList:
423             continue
424
425         featureList = featuresDict[gid]
426         newFeatureList = []
427         isPseudo = False
428         for (ftype, chrom, start, stop, sense) in featureList:
429             if ftype == "PSEUDO":
430                 isPseudo = True
431
432             if (start, stop, ftype) not in newFeatureList:
433                 notContained = True
434                 containedList = []
435                 for (fstart, fstop, ftype2) in newFeatureList:
436                     if start >= fstart and stop <= fstop:
437                         notContained = False
438
439                     if start < fstart and stop > fstop:
440                         containedList.append((fstart, fstop))
441
442                 if len(containedList) > 0:
443                     newFList = []
444                     notContained = True
445                     for (fstart, fstop, ftype2) in newFeatureList:
446                         if (fstart, fstop) not in containedList:
447                             newFList.append((fstart, fstop, ftype2))
448                             if start >= fstart and stop <= fstop:
449                                 notContained = False
450
451                     newFeatureList = newFList
452                 if notContained:
453                     newFeatureList.append((start, stop, ftype))
454
455         if ignorePseudo and isPseudo:
456             continue
457
458         if chrom not in featuresByChromDict:
459             featuresByChromDict[chrom] = []
460
461         for (start, stop, ftype) in newFeatureList:
462             featuresByChromDict[chrom].append((start, stop, gid, sense, ftype))
463
464     for chrom in featuresByChromDict:
465         featuresByChromDict[chrom].sort()
466
467     if regionComplement:
468         complementByChromDict = {}
469         complementIndex = 0
470         for chrom in featuresByChromDict:
471             complementByChromDict[chrom] = []
472             listLength = len(featuresByChromDict[chrom])
473             if listLength > 0:
474                 currentStart = 0
475                 for index in range(listLength):
476                     currentStop = featuresByChromDict[chrom][index][0]
477                     complementIndex += 1
478                     if currentStart < currentStop:
479                         complementByChromDict[chrom].append((currentStart, currentStop, "nonExon%d" % complementIndex, "F", "nonExon"))
480
481                     currentStart = featuresByChromDict[chrom][index][1]
482
483                 currentStop = maxStop
484                 complementByChromDict[chrom].append((currentStart, currentStop, "nonExon%d" % complementIndex, "F", "nonExon"))
485
486         return (featuresByChromDict, complementByChromDict)
487     else:
488         return featuresByChromDict
489
490
491 def getLocusByChromDict(genomeObject, upstream=0, downstream=0, useCDS=True,
492                         additionalRegionsDict={}, ignorePseudo=False, upstreamSpanTSS=False,
493                         lengthCDS=0, keepSense=False, adjustToNeighbor=True):
494     """ return a dictionary of gene loci. Can be used to retrieve additional
495     sequence upstream or downstream of gene, up to the next gene. Requires
496     cistematic, obviously.
497     Can filter-out pseudogenes and use additional regions outside of existing
498     gene models. Use upstreamSpanTSS to overlap half of the upstream region
499     over the TSS.
500     If lengthCDS > 0 bp, e.g. X, return only the starting X bp from CDS. If
501     lengthCDS < 0bp, return only the last X bp from CDS.
502     """ 
503     locusByChromDict = {}
504     if upstream == 0 and downstream == 0 and not useCDS:
505         print "getLocusByChromDict: asked for no sequence - returning empty dict"
506         return locusByChromDict
507     elif upstream > 0 and downstream > 0 and not useCDS:
508         print "getLocusByChromDict: asked for only upstream and downstream - returning empty dict"
509         return locusByChromDict
510     elif lengthCDS != 0 and not useCDS:
511         print "getLocusByChromDict: asked for partial CDS but not useCDS - returning empty dict"
512         return locusByChromDict
513     elif upstreamSpanTSS and lengthCDS != 0:
514         print "getLocusByChromDict: asked for TSS spanning and partial CDS - returning empty dict"
515         return locusByChromDict
516     elif lengthCDS > 0 and downstream > 0:
517         print "getLocusByChromDict: asked for discontinuous partial CDS from start and downstream - returning empty dict"
518         return locusByChromDict
519     elif lengthCDS < 0 and upstream > 0:
520         print "getLocusByChromDict: asked for discontinuous partial CDS from stop and upstream - returning empty dict"
521         return locusByChromDict
522
523     genome = genomeObject.genome
524     featuresDict = genomeObject.getallGeneFeatures()
525     if len(additionalRegionsDict) > 0:
526         sortList = []
527         for chrom in additionalRegionsDict:
528             for (label, start, stop, length) in additionalRegionsDict[chrom]:
529                 if label not in sortList:
530                     sortList.append(label)
531
532                 if label not in featuresDict:
533                     featuresDict[label] = []
534                     sense = "+"
535                 else:
536                     sense = featuresDict[label][0][-1]
537
538                 featuresDict[label].append(("custom", chrom, start, stop, sense))
539
540         for gid in sortList:
541             featuresDict[gid].sort(cmp=lambda x,y:cmp(x[2], y[2]))
542
543     for gid in featuresDict:
544         featureList = featuresDict[gid]
545         newFeatureList = []
546         for (ftype, chrom, start, stop, sense) in featureList:
547             newFeatureList.append((start, stop))
548
549         if ignorePseudo and ftype == "PSEUDO":
550             continue
551
552         newFeatureList.sort()
553
554         sense = featureList[0][-1]
555         gstart = newFeatureList[0][0]
556         gstop = newFeatureList[-1][1]
557         glen = abs(gstart - gstop)
558         if sense == "F":
559             if not useCDS and upstream > 0:
560                 if upstreamSpanTSS:
561                     if gstop > (gstart + upstream / 2):
562                         gstop = gstart + upstream / 2
563                 else:
564                     gstop = gstart
565             elif not useCDS and downstream > 0:
566                 gstart = gstop
567
568             if upstream > 0:
569                 if upstreamSpanTSS:
570                     distance = upstream / 2
571                 else:
572                     distance = upstream
573
574                 if adjustToNeighbor:
575                     nextGene = genomeObject.leftGeneDistance((genome, gid), distance * 2)
576                     if nextGene < distance * 2:
577                         distance = nextGene / 2
578
579                 if distance < 1:
580                     distance = 1
581
582                 gstart -= distance
583
584             if downstream > 0:
585                 distance = downstream
586                 if adjustToNeighbor:
587                     nextGene = genomeObject.rightGeneDistance((genome, gid), downstream * 2)
588                     if nextGene < downstream * 2:
589                         distance = nextGene / 2
590
591                 if distance < 1:
592                     distance = 1
593
594                 gstop += distance
595
596             if lengthCDS > 0:
597                 if lengthCDS < glen:
598                     gstop = newFeatureList[0][0] + lengthCDS
599
600             if lengthCDS < 0:
601                 if abs(lengthCDS) < glen:
602                     gstart = newFeatureList[-1][1] + lengthCDS
603         else:
604             if not useCDS and upstream > 0:
605                 if upstreamSpanTSS:
606                     if gstart < (gstop - upstream / 2):
607                         gstart = gstop - upstream / 2
608                 else:
609                     gstart = gstop
610             elif not useCDS and downstream > 0:
611                     gstop = gstart
612
613             if upstream > 0:
614                 if upstreamSpanTSS:
615                     distance = upstream /2
616                 else:
617                     distance = upstream
618
619                 if adjustToNeighbor:
620                     nextGene = genomeObject.rightGeneDistance((genome, gid), distance * 2)
621                     if nextGene < distance * 2:
622                         distance = nextGene / 2
623
624                 if distance < 1:
625                     distance = 1
626
627                 gstop += distance
628
629             if downstream > 0:
630                 distance = downstream
631                 if adjustToNeighbor:
632                     nextGene = genomeObject.leftGeneDistance((genome, gid), downstream * 2)
633                     if nextGene < downstream * 2:
634                         distance = nextGene / 2
635
636                 if distance < 1:
637                     distance = 1
638
639                 gstart -= distance
640
641             if lengthCDS > 0:
642                 if lengthCDS < glen:
643                     gstart = newFeatureList[-1][-1] - lengthCDS
644
645             if lengthCDS < 0:
646                 if abs(lengthCDS) < glen:
647                     gstop = newFeatureList[0][0] - lengthCDS
648
649         glen = abs(gstop - gstart)
650         if chrom not in locusByChromDict:
651             locusByChromDict[chrom] = []
652
653         if keepSense:
654             locusByChromDict[chrom].append((gstart, gstop, gid, glen, sense))
655         else:
656             locusByChromDict[chrom].append((gstart, gstop, gid, glen))
657
658     for chrom in locusByChromDict:
659         locusByChromDict[chrom].sort()
660
661     return locusByChromDict
662
663
664 def computeRegionBins(regionsByChromDict, hitDict, bins, readlen, regionList=[],
665                       normalizedTag=1., defaultRegionFormat=True, fixedFirstBin=-1,
666                       binLength=-1):
667     """ returns 2 dictionaries of bin counts and region lengths, given a dictionary of predefined regions,
668         a dictionary of reads, a number of bins, the length of reads, and optionally a list of regions
669         or a different weight / tag.
670     """
671     index = 0
672     regionsBins = {}
673     regionsLen = {}
674
675     if defaultRegionFormat:
676         regionIDField = 0
677         startField = 1
678         stopField = 2
679         lengthField = 3
680     else:
681         startField = 0
682         stopField = 1
683         regionIDField = 2
684         lengthField = 3
685
686     senseField = 4
687
688     print "entering computeRegionBins"
689     if len(regionList) > 0:
690         for readID in regionList:
691             regionsBins[readID] = [0.] * bins
692     else:
693         for chrom in regionsByChromDict:
694             for regionTuple in regionsByChromDict[chrom]:
695                 regionID = regionTuple[regionIDField]
696                 regionsBins[regionID] = [0.] * bins
697
698     for chrom in hitDict:
699         if chrom not in regionsByChromDict:
700             continue
701
702         for regionTuple in regionsByChromDict[chrom]:
703             regionID = regionTuple[regionIDField]
704             regionsLen[regionID] = regionTuple[lengthField]
705
706         print "%s\n" % chrom
707         startRegion = 0
708         for (tagStart, sense, weight) in hitDict[chrom]:
709             index += 1
710             if index % 100000 == 0:
711                 print "read %d " % index,
712
713             stopPoint = tagStart + readlen
714             if startRegion < 0:
715                 startRegion = 0
716
717             for regionTuple in regionsByChromDict[chrom][startRegion:]:
718                 start = regionTuple[startField]
719                 stop = regionTuple[stopField]
720                 regionID = regionTuple[regionIDField]
721                 rlen = regionTuple[lengthField]
722                 try:
723                     rsense = regionTuple[senseField]
724                 except:
725                     rsense = "F"
726
727                 if tagStart > stop:
728                     startRegion += 1
729                     continue
730
731                 if start > stopPoint:
732                     startRegion -= 10
733                     break
734
735                 if start <= tagStart <= stop:
736                     if binLength < 1:
737                         regionBinLength = rlen / bins
738                     else:
739                         regionBinLength = binLength
740
741                     startdist = tagStart - start
742                     if rsense == "F":
743                         # we are relying on python's integer division quirk
744                         binID = startdist / regionBinLength
745                         if (fixedFirstBin > 0) and (startdist < fixedFirstBin):
746                             binID = 0
747                         elif fixedFirstBin > 0:
748                             binID = 1
749
750                         if binID >= bins:
751                             binID = bins - 1
752
753                         try:
754                             regionsBins[regionID][binID] += normalizedTag * weight
755                         except KeyError:
756                             print "%s %s" % (regionID, str(binID))
757                     else:
758                         rdist = rlen - startdist
759                         binID = rdist / regionBinLength
760                         if (fixedFirstBin > 0) and (rdist < fixedFirstBin):
761                             binID = 0
762                         elif fixedFirstBin > 0:
763                             binID = 1
764
765                         if binID >= bins:
766                             binID = bins - 1
767
768                         try:
769                             regionsBins[regionID][binID] += normalizedTag * weight
770                         except KeyError:
771                             print "%s %s" % (regionID, str(binID))
772
773                     stopPoint = stop
774
775     return (regionsBins, regionsLen)
776
777
778 # TODO: The readDataset class is going to be replaced by Erange.ReadDataset but this will
779 # require going through all the code to make the changes needed.  Major project for another
780 # day, but it really needs to be done
781 class readDataset:
782     """ Class for storing reads from experiments. Assumes that custom scripts
783     will translate incoming data into a format that can be inserted into the
784     class using the insert* methods. Default class subtype ('DNA') includes
785     tables for unique and multireads, whereas 'RNA' subtype also includes a
786     splices table.
787     """
788
    def __init__(self, datafile, initialize=False, datasetType='', verbose=False, 
                 cache=False, reportCount=True):
        """ Open an RDS sqlite datafile.

        Creates an rds datafile (with fresh tables) if initialize is set
        to True, otherwise appends to the existing tables.  datasetType
        can be either 'DNA' or 'RNA'; it is only consulted when
        initializing - for an existing file the stored 'dataType'
        metadata wins.  With cache=True the file is first copied to the
        local cache directory and the copy is opened instead.  verbose
        prints the metadata (plus read counts when reportCount is True)
        and index status.
        """
        # connection/cursor handles; empty strings double as "not set"
        self.dbcon = ""
        self.memcon = ""
        self.dataType = ""
        self.rdsVersion = "1.1"
        self.memBacked = False
        self.memChrom = ""
        self.memCursor = ""
        self.cachedDBFile = ""

        if cache:
            if verbose:
                print "caching ...."

            # work on a local copy of the file; __del__ removes it
            self.cacheDB(datafile)
            dbfile = self.cachedDBFile
        else:
            dbfile = datafile

        self.dbcon = sqlite.connect(dbfile)
        # Row factory gives dict-style access (row["chrom"]) throughout
        self.dbcon.row_factory = sqlite.Row
        self.dbcon.execute("PRAGMA temp_store = MEMORY")
        if initialize:
            if datasetType == "":
                self.dataType = "DNA"
            else:
                self.dataType = datasetType

            self.initializeTables(self.dbcon)
        else:
            # existing file: trust its stored dataType over the argument
            metadata = self.getMetadata("dataType")
            self.dataType = metadata["dataType"]

        try:
            metadata = self.getMetadata("rdsVersion")
            self.rdsVersion = metadata["rdsVersion"]
        except:
            # older files carry no version entry; try to stamp the
            # current one, which fails on a read-only file
            try:
                self.insertMetadata([("rdsVersion", currentRDSversion)])
            except:
                print "could not add rdsVersion - read-only ?"
                self.rdsVersion = "pre-1.0"

        if verbose:
            if initialize:
                print "INITIALIZED dataset %s" % datafile
            else:
                print "dataset %s" % datafile

            metadata = self.getMetadata()
            print "metadata:"
            pnameList = metadata.keys()
            pnameList.sort()
            for pname in pnameList:
                print "\t" + pname + "\t" + metadata[pname]

            if reportCount:
                ucount = self.getUniqsCount()
                mcount = self.getMultiCount()
                if self.dataType == "DNA" and not initialize:
                    # counts may come back non-numeric (e.g. None on an
                    # empty table) - fall back to printing them verbatim
                    try:
                        print "\n%d unique reads and %d multireads" % (int(ucount), int(mcount))
                    except:
                        print "\n%s unique reads and %s multireads" % (ucount, mcount)
                elif self.dataType == 'RNA' and not initialize:
                    scount = self.getSplicesCount()
                    try:
                        print "\n%d unique reads, %d spliced reads and %d multireads" % (int(ucount), int(scount), int(mcount))
                    except:
                        print "\n%s unique reads, %s spliced reads and %s multireads" % (ucount, scount, mcount)

            print "default cache size is %d pages" % self.getDefaultCacheSize()
            if self.hasIndex():
                print "found index"
            else:
                print "not indexed"
869
870
871     def __len__(self):
872         """ return the number of usable reads in the dataset.
873         """
874         try:
875             total = self.getUniqsCount()
876         except:
877             total = 0
878
879         try:
880             total += self.getMultiCount()
881         except:
882             pass
883
884         if self.dataType == "RNA":
885             try:
886                 total += self.getSplicesCount()
887             except:
888                 pass
889
890         try:
891             total = int(total)
892         except:
893             total = 0
894
895         return total
896
897
898     def __del__(self):
899         """ cleanup copy in local cache, if present.
900         """
901         if self.cachedDBFile != "":
902             self.uncacheDB()
903
904
905     def cacheDB(self, filename):
906         """ copy geneinfoDB to a local cache.
907         """
908         self.cachedDBFile = tempfile.mktemp() + ".db"
909         shutil.copyfile(filename, self.cachedDBFile)
910
911
912     def saveCacheDB(self, filename):
913         """ copy geneinfoDB to a local cache.
914         """
915         shutil.copyfile(self.cachedDBFile, filename)
916
917
918     def uncacheDB(self):
919         """ delete geneinfoDB from local cache.
920         """
921         global cachedDBFile
922         if self.cachedDBFile != "":
923             try:
924                 os.remove(self.cachedDBFile)
925             except:
926                 print "could not delete %s" % self.cachedDBFile
927
928             self.cachedDB = ""
929
930
931     def attachDB(self, filename, asname):
932         """ attach another database file to the readDataset.
933         """
934         stmt = "attach '%s' as %s" % (filename, asname)
935         self.execute(stmt)
936
937
938     def detachDB(self, asname):
939         """ detach a database file to the readDataset.
940         """
941         stmt = "detach %s" % (asname)
942         self.execute(stmt)
943
944
945     def importFromDB(self, asname, table, ascolumns="*", destcolumns="", flagged=""):
946         """ import into current RDS the table (with columns destcolumns,
947             with default all columns) from the database file asname,
948             using the column specification of ascolumns (default all).
949         """
950         stmt = "insert into %s %s select %s from %s.%s" % (table, destcolumns, ascolumns, asname, table)
951         if flagged != "":
952             stmt += " where flag = '%s' " % flagged
953
954         self.execute(stmt, forceCommit=True)
955
956
957     def getTables(self, asname=""):
958         """ get a list of table names in a particular database file.
959         """
960         resultList = []
961
962         if self.memBacked:
963             sql = self.memcon.cursor()
964         else:
965             sql = self.dbcon.cursor()
966
967         if asname != "":
968             asname += "."
969
970         stmt = "select name from %ssqlite_master where type='table'" % asname
971         sql.execute(stmt)
972         results = sql.fetchall()
973
974         for row in results:
975             resultList.append(row["name"])
976
977         return resultList
978
979
980     def hasIndex(self):
981         """ check whether the RDS file has at least one index.
982         """
983         stmt = "select count(*) from sqlite_master where type='index'"
984         count = int(self.execute(stmt, returnResults=True)[0][0])
985         if count > 0:
986             return True
987
988         return False
989
990
991     def initializeTables(self, acon, cache=100000):
992         """ creates table schema in database connection acon, which is
993         typically a database file or an in-memory database.
994         """
995         acon.execute("PRAGMA DEFAULT_CACHE_SIZE = %d" % cache)
996         acon.execute("create table metadata (name varchar, value varchar)")
997         acon.execute("insert into metadata values('dataType','%s')" % self.dataType)
998         acon.execute("create table uniqs (ID INTEGER PRIMARY KEY, readID varchar, chrom varchar, start int, stop int, sense varchar, weight real, flag varchar, mismatch varchar)")
999         acon.execute("create table multi (ID INTEGER PRIMARY KEY, readID varchar, chrom varchar, start int, stop int, sense varchar, weight real, flag varchar, mismatch varchar)")
1000         if self.dataType == "RNA":
1001             acon.execute("create table splices (ID INTEGER PRIMARY KEY, readID varchar, chrom varchar, startL int, stopL int, startR int, stopR int, sense varchar, weight real, flag varchar, mismatch varchar)")
1002
1003         acon.commit()
1004
1005
1006     def getFileCursor(self):
1007         """ returns a cursor to file database for low-level (SQL)
1008         access to the data.
1009         """
1010         return self.dbcon.cursor()
1011
1012
1013     def getMemCursor(self):
1014         """ returns a cursor to memory database for low-level (SQL)
1015         access to the data.
1016         """
1017         return self.memcon.cursor()
1018
1019
1020     def getMetadata(self, valueName=""):
1021         """ returns a dictionary of metadata.
1022         """
1023         whereClause = ""
1024         resultsDict = {}
1025
1026         if valueName != "":
1027             whereClause = " where name = '%s' " % valueName
1028
1029         if self.memBacked:
1030             sql = self.memcon.cursor()
1031         else:
1032             sql = self.dbcon.cursor()
1033
1034         sql.execute("select name, value from metadata" + whereClause)
1035         results = sql.fetchall()
1036
1037         for row in results:
1038             pname = row["name"]
1039             pvalue = row["value"]
1040             if pname not in resultsDict:
1041                 resultsDict[pname] = pvalue
1042             else:
1043                 trying = True
1044                 index = 2
1045                 while trying:
1046                     newName = pname + ":" + str(index)
1047                     if newName not in resultsDict:
1048                         resultsDict[newName] = pvalue
1049                         trying = False
1050
1051                     index += 1
1052
1053         return resultsDict
1054
1055
1056     def getReadSize(self):
1057         """ returns readsize if defined in metadata.
1058         """
1059         metadata = self.getMetadata()
1060         if "readsize" not in metadata:
1061             print "no readsize parameter defined - returning 0"
1062             return 0
1063         else:
1064             mysize = metadata["readsize"]
1065             if "import" in mysize:
1066                 mysize = mysize.split()[0]
1067
1068             return int(mysize)
1069
1070
1071     def getDefaultCacheSize(self):
1072         """ returns the default cache size.
1073         """
1074         return int(self.execute("PRAGMA DEFAULT_CACHE_SIZE", returnResults=True)[0][0])
1075
1076
1077     def getChromosomes(self, table="uniqs", fullChrom=True):
1078         """ returns a list of distinct chromosomes in table.
1079         """
1080         statement = "select distinct chrom from %s" % table
1081         if self.memBacked:
1082             sql = self.memcon.cursor()
1083         else:
1084             sql = self.dbcon.cursor()
1085
1086         sql.execute(statement)
1087         results = []
1088         for row in sql:
1089             if fullChrom:
1090                 if row["chrom"] not in results:
1091                     results.append(row["chrom"])
1092             else:
1093                 if  len(row["chrom"][3:].strip()) < 1:
1094                     continue
1095
1096                 if row["chrom"][3:] not in results:
1097                     results.append(row["chrom"][3:])
1098
1099         results.sort()
1100
1101         return results
1102
1103
1104     def getMaxCoordinate(self, chrom, verbose=False, doUniqs=True,
1105                          doMulti=False, doSplices=False):
1106         """ returns the maximum coordinate for reads on a given chromosome.
1107         """
1108         maxCoord = 0
1109         if self.memBacked:
1110             sql = self.memcon.cursor()
1111         else:
1112             sql = self.dbcon.cursor()
1113
1114         if doUniqs:
1115             try:
1116                 sql.execute("select max(start) from uniqs where chrom = '%s'" % chrom)
1117                 maxCoord = int(sql.fetchall()[0][0])
1118             except:
1119                 print "couldn't retrieve coordMax for chromosome %s" % chrom
1120
1121         if doSplices:
1122             sql.execute("select max(startR) from splices where chrom = '%s'" % chrom)
1123             try:
1124                 spliceMax = int(sql.fetchall()[0][0])
1125                 if spliceMax > maxCoord:
1126                     maxCoord = spliceMax
1127             except:
1128                 pass
1129
1130         if doMulti:
1131             sql.execute("select max(start) from multi where chrom = '%s'" % chrom)
1132             try:
1133                 multiMax = int(sql.fetchall()[0][0])
1134                 if multiMax > maxCoord:
1135                     maxCoord = multiMax
1136             except:
1137                 pass
1138
1139         if verbose:
1140             print "%s maxCoord: %d" % (chrom, maxCoord)
1141
1142         return maxCoord
1143
1144
    def getReadsDict(self, verbose=False, bothEnds=False, noSense=False, fullChrom=False, chrom="",
                     flag="", withWeight=False, withFlag=False, withMismatch=False, withID=False,
                     withChrom=False, withPairID=False, doUniqs=True, doMulti=False, findallOptimize=False,
                     readIDDict=False, readLike="", start=-1, stop=-1, limit=-1, hasMismatch=False,
                     flagLike=False, strand="", entryDict=False, combine5p=False):
        """ returns a dictionary of reads in a variety of formats
        and which can be restricted by chromosome or custom-flag.
        Returns unique reads by default, but can return multireads
        with doMulti set to True.

        Keying: by chromosome normally, or by readID when readIDDict is
        True.  Each value is a list of reads; a read is a list of fields
        (ordered start, [stop], [sense], [weight], [flag], [mismatch],
        [readID], [chrom], [pairID] as selected by the with* flags) or a
        dict of the same fields when entryDict is True.  findallOptimize
        collapses reads to (start, sense, summed weight) tuples under the
        single key `chrom`.  combine5p merges reads sharing a 5' start
        across the uniqs/multi subqueries before grouping.
        """
        whereClause = []
        resultsDict = {}

        # -- build the WHERE clause from the various filters --
        # a chrom equal to the currently memory-loaded chromosome needs
        # no restriction: the mem table only holds that chromosome
        if chrom != "" and chrom != self.memChrom:
            whereClause.append("chrom = '%s'" % chrom)

        if flag != "":
            if flagLike:
                flagLikeClause = string.join(['flag LIKE "%', flag, '%"'], "")
                whereClause.append(flagLikeClause)
            else:
                whereClause.append("flag = '%s'" % flag)

        if start > -1:
            whereClause.append("start > %d" % start)

        if stop > -1:
            whereClause.append("stop < %d" % stop)

        if len(readLike) > 0:
            readIDClause = string.join(["readID LIKE  '", readLike, "%'"], "")
            whereClause.append(readIDClause)

        if hasMismatch:
            whereClause.append("mismatch != ''")

        if strand in ["+", "-"]:
            whereClause.append("sense = '%s'" % strand)

        if len(whereClause) > 0:
            whereStatement = string.join(whereClause, " and ")
            whereQuery = "where %s" % whereStatement
        else:
            whereQuery = ""

        # -- build the SELECT clause --
        groupBy = []
        if findallOptimize:
            # positional-per-position aggregation; row_factory is
            # switched off below so plain tuples come back
            selectClause = ["select start, sense, sum(weight)"]
            groupBy = ["GROUP BY start, sense"]
        else:
            # ID, chrom, start and readID are always selected, even
            # without withID/withChrom; optional fields follow in order
            selectClause = ["select ID, chrom, start, readID"]
            if bothEnds:
                selectClause.append("stop")

            if not noSense:
                selectClause.append("sense")

            if withWeight:
                selectClause.append("weight")

            if withFlag:
                selectClause.append("flag")

            if withMismatch:
                selectClause.append("mismatch")

        if limit > 0 and not combine5p:
            groupBy.append("LIMIT %d" % limit)

        selectQuery = string.join(selectClause, ",")
        groupQuery = string.join(groupBy)
        # -- assemble the statement: uniqs, multi, or their UNION ALL --
        if doUniqs:
            stmt = [selectQuery, "from uniqs", whereQuery, groupQuery]
            if doMulti:
                stmt.append("UNION ALL")
                stmt.append(selectQuery)
                stmt.append("from multi")
                stmt.append(whereQuery)
                stmt.append(groupQuery)
        else:
            stmt = [selectQuery, "from multi", whereQuery]

        if combine5p:
            # wrap the per-table subqueries so that reads sharing a 5'
            # start across tables are grouped before aggregation
            if findallOptimize:
                selectQuery = "select start, sense, weight, chrom"

            if doUniqs:
                subSelect = [selectQuery, "from uniqs", whereQuery]
                if doMulti:
                    subSelect.append("union all")
                    subSelect.append(selectQuery)
                    subSelect.append("from multi")
                    subSelect.append(whereQuery)
            else:
                subSelect = [selectQuery, "from multi", whereQuery]

            sqlStmt = string.join(subSelect)
            if findallOptimize:
                selectQuery = "select start, sense, sum(weight)"

            stmt = [selectQuery, "from (", sqlStmt, ") group by chrom,start having ( count(start) > 1 and count(chrom) > 1) union",
                    selectQuery, "from(", sqlStmt, ") group by chrom, start having ( count(start) = 1 and count(chrom) = 1)"]

        # -- pick cursor, row factory and ordering per output mode --
        if findallOptimize:
            if self.memBacked:
                # tuples are faster than Row objects for the bulk fetch
                self.memcon.row_factory = None
                sql = self.memcon.cursor()
            else:
                self.dbcon.row_factory = None
                sql = self.dbcon.cursor()

            stmt.append("order by start")
        elif readIDDict:
            if self.memBacked:
                sql = self.memcon.cursor()
            else:
                sql = self.dbcon.cursor()

            stmt.append("order by readID, start")
        else:
            if self.memBacked:
                sql = self.memcon.cursor()
            else:
                sql = self.dbcon.cursor()

            stmt.append("order by chrom, start")

        sqlQuery = string.join(stmt)
        sql.execute(sqlQuery)

        if findallOptimize:
            resultsDict[chrom] = [[int(row[0]), row[1], float(row[2])] for row in sql]
            # restore the Row factory switched off above
            if self.memBacked:
                self.memcon.row_factory = sqlite.Row
            else:
                self.dbcon.row_factory = sqlite.Row
        else:
            currentChrom = ""
            currentReadID = ""
            pairID = 0
            for row in sql:
                readID = row["readID"]
                if fullChrom:
                    chrom = row["chrom"]
                else:
                    chrom = row["chrom"][3:]

                # start a new result list whenever the grouping key
                # (chrom or readID) changes; rows are ordered to match
                if not readIDDict and chrom != currentChrom:
                    resultsDict[chrom] = []
                    currentChrom = chrom
                    dictKey = chrom
                elif readIDDict:
                    theReadID = readID
                    # strip "::multiplicity" and "/pairID" decorations
                    if "::" in readID:
                        (theReadID, multiplicity) = readID.split("::")

                    if "/" in theReadID and withPairID:
                        (theReadID, pairID) = readID.split("/")

                    if theReadID != currentReadID:
                        resultsDict[theReadID] = []
                        currentReadID = theReadID
                        dictKey = theReadID

                if entryDict:
                    newrow = {"start": int(row["start"])}
                    if bothEnds:
                        newrow["stop"] = int(row["stop"])

                    if not noSense:
                        newrow["sense"] = row["sense"]

                    if withWeight:
                        newrow["weight"] = float(row["weight"])

                    if withFlag:
                        newrow["flag"] = row["flag"]

                    if withMismatch:
                        newrow["mismatch"] = row["mismatch"]

                    if withID:
                        newrow["readID"] = readID

                    if withChrom:
                        newrow["chrom"] = chrom

                    if withPairID:
                        newrow["pairID"] = pairID
                else:
                    # list form: field order mirrors the dict form above
                    newrow = [int(row["start"])]
                    if bothEnds:
                        newrow.append(int(row["stop"]))

                    if not noSense:
                        newrow.append(row["sense"])

                    if withWeight:
                        newrow.append(float(row["weight"]))

                    if withFlag:
                        newrow.append(row["flag"])

                    if withMismatch:
                        newrow.append(row["mismatch"])

                    if withID:
                        newrow.append(readID)

                    if withChrom:
                        newrow.append(chrom)

                    if withPairID:
                        newrow.append(pairID)

                resultsDict[dictKey].append(newrow)

        return resultsDict
1363
1364
1365     def getSplicesDict(self, verbose=False, noSense=False, fullChrom=False, chrom="",
1366                        flag="", withWeight=False, withFlag=False, withMismatch=False,
1367                        withID=False, withChrom=False, withPairID=False, readIDDict=False,
1368                        splitRead=False, hasMismatch=False, flagLike=False, start=-1,
1369                        stop=-1, strand="", entryDict=False):
1370         """ returns a dictionary of spliced reads in a variety of
1371         formats and which can be restricted by chromosome or custom-flag.
1372         Returns unique spliced reads for now.
1373         """
1374         whereClause = []
1375         resultsDict = {}
1376
1377         if chrom != "" and chrom != self.memChrom:
1378             whereClause = ["chrom = '%s'" % chrom]
1379
1380         if flag != "":
1381             if flagLike:
1382                 flagLikeClause = string.join(['flag LIKE "%', flag, '%"'], "")
1383                 whereClause.append(flagLikeClause)
1384             else:
1385                 whereClause.append("flag = '%s'" % flag)
1386
1387         if hasMismatch:
1388             whereClause.append("mismatch != ''")
1389
1390         if strand != "":
1391             whereClause.append("sense = '%s'" % strand)
1392
1393         if start > -1:
1394             whereClause.append("startL > %d" % start)
1395
1396         if stop > -1:
1397             whereClause.append("stopR < %d" % stop)
1398
1399         if len(whereClause) > 0:
1400             whereStatement = string.join(whereClause, " and ")
1401             whereQuery = "where %s" % whereStatement
1402         else:
1403             whereQuery = ""
1404
1405         selectClause = ["select ID, chrom, startL, stopL, startR, stopR, readID"]
1406         if not noSense:
1407             selectClause.append("sense")
1408
1409         if withWeight:
1410             selectClause.append("weight")
1411
1412         if withFlag:
1413             selectClause.append("flag")
1414
1415         if withMismatch:
1416             selectClause.append("mismatch")
1417
1418         selectQuery = string.join(selectClause, " ,")
1419         if self.memBacked:
1420             sql = self.memcon.cursor()
1421         else:
1422             sql = self.dbcon.cursor()
1423
1424         if chrom == "" and not readIDDict:
1425             stmt = "select distinct chrom from splices %s" % whereQuery
1426             sql.execute(stmt)
1427             for row in sql:
1428                 if fullChrom:
1429                     chrom = row["chrom"]
1430                 else:
1431                     chrom = row["chrom"][3:]
1432
1433                 resultsDict[chrom] = []
1434         elif chrom != "" and not readIDDict:
1435             resultsDict[chrom] = []
1436
1437         stmt = "%s from splices %s order by chrom, startL" % (selectQuery, whereQuery)
1438         sql.execute(stmt)
1439         currentReadID = ""
1440         for row in sql:
1441             pairID = 0
1442             readID = row["readID"]
1443             if fullChrom:
1444                 chrom = row["chrom"]
1445             else:
1446                 chrom = row["chrom"][3:]
1447
1448             if readIDDict:
1449                 if "/" in readID:
1450                     (theReadID, pairID) = readID.split("/")
1451                 else:
1452                     theReadID = readID
1453
1454                 if theReadID != currentReadID:
1455                     resultsDict[theReadID] = []
1456                     currentReadID = theReadID
1457                     dictKey = theReadID
1458             else:
1459                 dictKey = chrom
1460
1461             if entryDict:
1462                 newrow = {"startL": int(row["startL"])}
1463                 newrow["stopL"] = int(row["stopL"])
1464                 newrow["startR"] = int(row["startR"])
1465                 newrow["stopR"] = int(row["stopR"])
1466                 if not noSense:
1467                     newrow["sense"] = row["sense"]
1468
1469                 if withWeight:
1470                     newrow["weight"] = float(row["weight"])
1471
1472                 if withFlag:
1473                     newrow["flag"] = row["flag"]
1474
1475                 if withMismatch:
1476                     newrow["mismatch"] = row["mismatch"]
1477
1478                 if withID:
1479                     newrow["readID"] = readID
1480
1481                 if withChrom:
1482                     newrow["chrom"] = chrom
1483
1484                 if withPairID:
1485                     newrow["pairID"] = pairID
1486
1487                 if splitRead:
1488                     leftDict = newrow
1489                     del leftDict["startR"]
1490                     del leftDict["stopR"]
1491                     rightDict = newrow
1492                     del rightDict["start"]
1493                     del rightDict["stopL"]
1494                     resultsDict[dictKey].append(leftDict)
1495                     resultsDict[dictKey].append(rightDict)
1496                 else:
1497                     resultsDict[dictKey].append(newrow)
1498             else:
1499                 newrow = [int(row["startL"])]
1500                 newrow.append(int(row["stopL"]))
1501                 newrow.append(int(row["startR"]))
1502                 newrow.append(int(row["stopR"]))
1503                 if not noSense:
1504                     newrow.append(row["sense"])
1505
1506                 if withWeight:
1507                     newrow.append(float(row["weight"]))
1508
1509                 if withFlag:
1510                     newrow.append(row["flag"])
1511
1512                 if withMismatch:
1513                     newrow.append(row["mismatch"])
1514
1515                 if withID:
1516                     newrow.append(readID)
1517
1518                 if withChrom:
1519                     newrow.append(chrom)
1520
1521                 if withPairID:
1522                     newrow.append(pairID)
1523
1524                 if splitRead:
1525                     resultsDict[dictKey].append(newrow[:2] + newrow[4:])
1526                     resultsDict[dictKey].append(newrow[2:])
1527                 else:
1528                     resultsDict[dictKey].append(newrow)
1529
1530         return resultsDict
1531
1532
1533     def getCounts(self, chrom="", rmin="", rmax="", uniqs=True, multi=False,
1534                   splices=False, reportCombined=True, sense="both"):
1535         """ return read counts for a given region.
1536         """
1537         ucount = 0
1538         mcount = 0
1539         scount = 0
1540         restrict = ""
1541         if sense in ["+", "-"]:
1542             restrict = " sense ='%s' " % sense
1543
1544         if uniqs:
1545             try:
1546                 ucount = float(self.getUniqsCount(chrom, rmin, rmax, restrict))
1547             except:
1548                 ucount = 0
1549
1550         if multi:
1551             try:
1552                 mcount = float(self.getMultiCount(chrom, rmin, rmax, restrict))
1553             except:
1554                 mcount = 0
1555
1556         if splices:
1557             try:
1558                 scount = float(self.getSplicesCount(chrom, rmin, rmax, restrict))
1559             except:
1560                 scount = 0
1561
1562         if reportCombined:
1563             total = ucount + mcount + scount
1564             return total
1565         else:
1566             return (ucount, mcount, scount)
1567
1568
1569     def getTotalCounts(self, chrom="", rmin="", rmax=""):
1570         return self.getCounts(chrom, rmin, rmax, uniqs=True, multi=True, splices=True, reportCombined=True, sense="both")
1571
1572
1573     def getTableEntryCount(self, table, chrom="", rmin="", rmax="", restrict="", distinct=False, startField="start"):
1574         """ returns the number of row in the uniqs table.
1575         """
1576         whereClause = []
1577         count = 0
1578
1579         if chrom !=""  and chrom != self.memChrom:
1580             whereClause = ["chrom='%s'" % chrom]
1581
1582         if rmin != "":
1583             whereClause.append("%s >= %s" % (startField, str(rmin)))
1584
1585         if rmax != "":
1586             whereClause.append("%s <= %s" % (startField, str(rmax)))
1587
1588         if restrict != "":
1589             whereClause.append(restrict)
1590
1591         if len(whereClause) > 0:
1592             whereStatement = string.join(whereClause, " and ")
1593             whereQuery = "where %s" % whereStatement
1594         else:
1595             whereQuery = ""
1596
1597         if self.memBacked:
1598             sql = self.memcon.cursor()
1599         else:
1600             sql = self.dbcon.cursor()
1601
1602         if distinct:
1603             sql.execute("select count(distinct chrom+start+sense) from %s %s" % (table, whereQuery))
1604         else:
1605             sql.execute("select sum(weight) from %s %s" % (table, whereQuery))
1606
1607         result = sql.fetchone()
1608
1609         try:
1610             count = int(result[0])
1611         except:
1612             count = 0
1613
1614         return count
1615
1616
1617     def getSplicesCount(self, chrom="", rmin="", rmax="", restrict="", distinct=False):
1618         """ returns the number of row in the splices table.
1619         """
1620         return self.getTableEntryCount("splices", chrom, rmin, rmax, restrict, distinct, startField="startL")
1621
1622
1623     def getUniqsCount(self, chrom="", rmin="", rmax="", restrict="", distinct=False):
1624         """ returns the number of distinct readIDs in the uniqs table.
1625         """
1626         return self.getTableEntryCount("uniqs", chrom, rmin, rmax, restrict, distinct)
1627
1628
1629     def getMultiCount(self, chrom="", rmin="", rmax="", restrict="", distinct=False):
1630         """ returns the total weight of readIDs in the multi table.
1631         """
1632         return self.getTableEntryCount("multi", chrom, rmin, rmax, restrict, distinct)
1633
1634
1635     def getReadIDs(self, uniqs=True, multi=False, splices=False, paired=False, limit=-1):
1636         """ get readID's.
1637         """
1638         stmt = []
1639         limitPart = ""
1640         if limit > 0:
1641             limitPart = "LIMIT %d" % limit
1642
1643         if uniqs:
1644             stmt.append("select readID from uniqs")
1645
1646         if multi:
1647             stmt.append("select readID from multi")
1648
1649         if splices:
1650             stmt.append("select readID from splices")
1651
1652         if len(stmt) > 0:
1653             selectPart = string.join(stmt, " union ")
1654         else:
1655             selectPart = ""
1656
1657         sqlQuery = "%s group by readID %s" (selectPart, limitPart)
1658         if self.memBacked:
1659             sql = self.memcon.cursor()
1660         else:
1661             sql = self.dbcon.cursor()
1662
1663         sql.execute(sqlQuery)
1664         result = sql.fetchall()
1665
1666         if paired:
1667             return [x.split("/")[0][0] for x in result]
1668         else:
1669             return [x[0] for x in result]
1670
1671
    def getMismatches(self, mischrom = None, verbose=False, useSplices=True):
        """ returns the uniq and spliced mismatches in a dictionary keyed by
            chromosome; each entry is a list of
            [readStart, genomicPosition, readBase, referenceBase].
            Mismatch descriptors are comma-separated strings apparently of the
            form <refBase><positionInRead><readBase> (e.g. "A10G") -- confirm
            against the aligner output. Descriptors containing an N are skipped.
            If mischrom is given only that chromosome is examined, otherwise
            all chromosomes are processed in sorted order. Splice mismatches
            are included for RNA datasets when useSplices is True.
        """
        revcomp = {"A": "T",
                   "T": "A",
                   "G": "C",
                   "C": "G",
                   "N": "N"
        }

        readlen = self.getReadSize()
        # restrict to one chromosome if requested, else walk them all
        if mischrom:
            hitChromList = [mischrom]
        else:
            hitChromList = self.getChromosomes()
            hitChromList.sort()

        snpDict = {}
        for achrom in hitChromList:
            if verbose:
                print "getting mismatches from chromosome %s" % (achrom)

            snpDict[achrom] = []
            hitDict = self.getReadsDict(fullChrom=True, chrom=achrom, withMismatch=True, findallOptimize=False, hasMismatch=True)
            # splice mismatches (RNA only): the within-read position has to be
            # mapped onto one of the two exonic halves of the spliced read
            if useSplices and self.dataType == "RNA":
                spliceDict = self.getSplicesDict(fullChrom=True, chrom=achrom, withMismatch=True, readIDDict=True, hasMismatch=True)
                spliceIDList = spliceDict.keys()
                for k in spliceIDList:
                    (startpos, lefthalf, rightstart, endspos, sense, mismatches) = spliceDict[k][0]
                    spMismatchList = mismatches.split(",")
                    for mismatch in spMismatchList:
                        if "N" in mismatch:
                            continue

                        change_len = len(mismatch)
                        if sense == "+":
                            change_from = mismatch[0]
                            change_base = mismatch[change_len-1]
                            change_pos = int(mismatch[1:change_len-1])
                        elif sense == "-":
                            # minus strand: complement both bases and mirror
                            # the 1-based position within the read
                            change_from = revcomp[mismatch[0]]
                            change_base = revcomp[mismatch[change_len-1]]
                            change_pos = readlen - int(mismatch[1:change_len-1]) + 1

                        # decide which half of the splice the mismatch lands in
                        # and convert to a genomic coordinate
                        firsthalf = int(lefthalf)-int(startpos)+1
                        secondhalf = 0
                        if int(change_pos) <= int(firsthalf):
                            change_at = startpos + change_pos - 1
                        else:
                            secondhalf = change_pos - firsthalf
                            change_at = rightstart + secondhalf

                        snpDict[achrom].append([startpos, change_at, change_base, change_from])

            if achrom not in hitDict:
                continue

            # uniq (unspliced) read mismatches: plain offset from read start
            for (start, sense, mismatches) in hitDict[achrom]:
                mismatchList = mismatches.split(",")
                for mismatch in mismatchList:
                    if "N" in mismatch:
                        continue

                    change_len = len(mismatch)
                    if sense == "+":
                        change_from = mismatch[0]
                        change_base = mismatch[change_len-1]
                        change_pos = int(mismatch[1:change_len-1])
                    elif sense == "-":
                        change_from = revcomp[mismatch[0]]
                        change_base = revcomp[mismatch[change_len-1]]
                        change_pos = readlen - int(mismatch[1:change_len-1]) + 1

                    change_at = start + change_pos - 1
                    snpDict[achrom].append([start, change_at, change_base, change_from])

        return snpDict
1749
1750
1751     def getChromProfile(self, chromosome, cstart=-1, cstop=-1, useMulti=True,
1752                         useSplices=False, normalizationFactor = 1.0, trackStrand=False,
1753                         keepStrand="both", shiftValue=0):
1754         """return a profile of the chromosome as an array of per-base read coverage....
1755             keepStrand = 'both', 'plusOnly', or 'minusOnly'.
1756             Will also shift position of unique and multireads (but not splices) if shift is a natural number
1757         """
1758         metadata = self.getMetadata()
1759         readlen = int(metadata["readsize"])
1760         dataType = metadata["dataType"]
1761         scale = 1. / normalizationFactor
1762         shift = {}
1763         shift["+"] = int(shiftValue)
1764         shift["-"] = -1 * int(shiftValue)
1765
1766         if cstop > 0:
1767             lastNT = self.getMaxCoordinate(chromosome, doMulti=useMulti, doSplices=useSplices) + readlen
1768         else:
1769             lastNT = cstop - cstart + readlen + shift["+"]
1770
1771         chromModel = array("f", [0.] * lastNT)
1772         hitDict = self.getReadsDict(fullChrom=True, chrom=chromosome, withWeight=True, doMulti=useMulti, start=cstart, stop=cstop, findallOptimize=True)
1773         if cstart < 0:
1774             cstart = 0
1775
1776         for (hstart, sense, weight) in hitDict[chromosome]:
1777             hstart = hstart - cstart + shift[sense]
1778             for currentpos in range(hstart,hstart+readlen):
1779                 try:
1780                     if not trackStrand or (sense == "+" and keepStrand != "minusOnly"):
1781                         chromModel[currentpos] += scale * weight
1782                     elif sense == '-' and keepStrand != "plusOnly":
1783                         chromModel[currentpos] -= scale * weight
1784                 except:
1785                     continue
1786
1787         del hitDict
1788         if useSplices and dataType == "RNA":
1789             if cstop > 0:
1790                 spliceDict = self.getSplicesDict(fullChrom=True, chrom=chromosome, withID=True, start=cstart, stop=cstop)
1791             else:
1792                 spliceDict = self.getSplicesDict(fullChrom=True, chrom=chromosome, withID=True)
1793    
1794             if chromosome in spliceDict:
1795                 for (Lstart, Lstop, Rstart, Rstop, rsense, readName) in spliceDict[chromosome]:
1796                     if (Rstop - cstart) < lastNT:
1797                         for index in range(abs(Lstop - Lstart)):
1798                             currentpos = Lstart - cstart + index
1799                             # we only track unique splices
1800                             if not trackStrand or (rsense == "+" and keepStrand != "minusOnly"):
1801                                 chromModel[currentpos] += scale
1802                             elif rsense == "-" and keepStrand != "plusOnly":
1803                                 chromModel[currentpos] -= scale
1804
1805                         for index in range(abs(Rstop - Rstart)):
1806                             currentpos = Rstart - cstart + index
1807                             # we only track unique splices
1808                             if not trackStrand or (rsense == "+" and keepStrand != "minusOnly"):
1809                                 chromModel[currentpos] += scale
1810                             elif rsense == "-" and keepStrand != "plusOnly":
1811                                 chromModel[currentpos] -= scale
1812
1813             del spliceDict
1814
1815         return chromModel
1816
1817
1818     def insertMetadata(self, valuesList):
1819         """ inserts a list of (pname, pvalue) into the metadata
1820         table.
1821         """
1822         self.dbcon.executemany("insert into metadata(name, value) values (?,?)", valuesList)
1823         self.dbcon.commit()
1824
1825
1826     def updateMetadata(self, pname, newValue, originalValue=""):
1827         """ update a metadata field given the original value and the new value.
1828         """
1829         stmt = "update metadata set value='%s' where name='%s'" % (str(newValue), pname)
1830         if originalValue != "":
1831             stmt += " and value='%s' " % str(originalValue)
1832
1833         self.dbcon.execute(stmt)
1834         self.dbcon.commit()
1835
1836
1837     def insertUniqs(self, valuesList):
1838         """ inserts a list of (readID, chrom, start, stop, sense, weight, flag, mismatch)
1839         into the uniqs table.
1840         """
1841         self.dbcon.executemany("insert into uniqs(ID, readID, chrom, start, stop, sense, weight, flag, mismatch) values (NULL,?,?,?,?,?,?,?,?)", valuesList)
1842         self.dbcon.commit()
1843
1844
1845     def insertMulti(self, valuesList):
1846         """ inserts a list of (readID, chrom, start, stop, sense, weight, flag, mismatch)
1847         into the multi table.
1848         """
1849         self.dbcon.executemany("insert into multi(ID, readID, chrom, start, stop, sense, weight, flag, mismatch) values (NULL,?,?,?,?,?,?,?,?)", valuesList)
1850         self.dbcon.commit()
1851
1852
1853     def insertSplices(self, valuesList):
1854         """ inserts a list of (readID, chrom, startL, stopL, startR, stopR, sense, weight, flag, mismatch)
1855         into the splices table.
1856         """
1857         self.dbcon.executemany("insert into splices(ID, readID, chrom, startL, stopL, startR, stopR, sense, weight, flag, mismatch) values (NULL,?,?,?,?,?,?,?,?,?,?)", valuesList)
1858         self.dbcon.commit()
1859
1860
1861     def flagReads(self, regionsList, uniqs=True, multi=False, splices=False, sense="both"):
1862         """ update reads on file database in a list region of regions for a chromosome to have a new flag.
1863             regionsList must have 4 fields per region of the form (flag, chrom, start, stop) or, with
1864             sense set to '+' or '-', 5 fields per region of the form (flag, chrom, start, stop, sense).
1865         """
1866         restrict = ""
1867         if sense != "both":
1868             restrict = " and sense = ? "
1869
1870         if uniqs:
1871             self.dbcon.executemany("UPDATE uniqs SET flag = ? where chrom = ? and start >= ? and start < ? " + restrict, regionsList)
1872
1873         if multi:
1874             self.dbcon.executemany("UPDATE multi SET flag = ? where chrom = ? and start >= ? and start < ? " + restrict, regionsList)
1875
1876         if self.dataType == "RNA" and splices:
1877             self.dbcon.executemany("UPDATE splices SET flag = flag || ' L:' || ? where chrom = ? and startL >= ? and startL < ? " + restrict, regionsList)
1878             self.dbcon.executemany("UPDATE splices SET flag = flag || ' R:' || ? where chrom = ? and startR >= ? and startR < ? " + restrict, regionsList)
1879
1880         self.dbcon.commit()
1881
1882
1883     def setFlags(self, flag, uniqs=True, multi=True, splices=True):
1884         """ set the flag fields in the entire dataset to clear. Useful for rerunning an analysis from scratch.
1885         """
1886         if uniqs:
1887             self.dbcon.execute("UPDATE uniqs SET flag = '%s'" % flag)
1888
1889         if multi:
1890             self.dbcon.execute("UPDATE multi SET flag = '%s'" % flag)
1891
1892         if self.dataType == 'RNA' and splices:
1893             self.dbcon.execute("UPDATE splices SET flag = '%s'" % flag)
1894
1895         self.dbcon.commit()
1896
1897
1898     def resetFlags(self, uniqs=True, multi=True, splices=True):
1899         """ reset the flag fields in the entire dataset to clear. Useful for rerunning an analysis from scratch.
1900         """
1901         if uniqs:
1902             self.dbcon.execute("UPDATE uniqs SET flag = ''")
1903
1904         if multi:
1905             self.dbcon.execute("UPDATE multi SET flag = ''")
1906
1907         if self.dataType == "RNA" and splices:
1908             self.dbcon.execute("UPDATE splices SET flag = ''")
1909
1910         self.dbcon.commit()
1911
1912
1913     def reweighMultireads(self, readList):
1914         self.dbcon.executemany("UPDATE multi SET weight = ? where chrom = ? and start = ? and readID = ? ", readList)
1915
1916
1917     def setSynchronousPragma(self, value="ON"):
1918         try:
1919             self.dbcon.execute("PRAGMA SYNCHRONOUS = %s" % value)
1920         except:
1921             print "warning: couldn't set PRAGMA SYNCHRONOUS = %s" % value
1922
1923
1924     def setDBcache(self, cache, default=False):
1925         self.dbcon.execute("PRAGMA CACHE_SIZE = %d" % cache)
1926         if default:
1927             self.dbcon.execute('PRAGMA DEFAULT_CACHE_SIZE = %d' % cache)
1928
1929
1930     def execute(self, statement, returnResults=False, forceCommit=False):
1931         if self.memBacked:
1932             sql = self.memcon.cursor()
1933         else:
1934             sql = self.dbcon.cursor()
1935
1936         sql.execute(statement)
1937         if returnResults:
1938             result = sql.fetchall()
1939             return result
1940
1941         if forceCommit:
1942             if self.memBacked:
1943                 self.memcon.commit()
1944             else:
1945                 self.dbcon.commit()
1946
1947
    def buildIndex(self, cache=100000):
        """ Builds the file indeces for the main tables.
            Cache is the number of 1.5 kb pages to keep in memory.
            100000 pages translates into 150MB of RAM, which is our default.
            Splice indices are only built for RNA datasets. Progress is
            printed after each index.
        """
        # only enlarge the page cache if it beats the current default
        if cache > self.getDefaultCacheSize():
            self.setDBcache(cache)
        # turn synchronous writes off while indexing for speed; restored below
        self.setSynchronousPragma("OFF")
        self.dbcon.execute("CREATE INDEX uPosIndex on uniqs(chrom, start)")
        print "built uPosIndex"
        self.dbcon.execute("CREATE INDEX uChromIndex on uniqs(chrom)")
        print "built uChromIndex"
        self.dbcon.execute("CREATE INDEX mPosIndex on multi(chrom, start)")
        print "built mPosIndex"
        self.dbcon.execute("CREATE INDEX mChromIndex on multi(chrom)")
        print "built mChromIndex"

        if self.dataType == "RNA":
            # splices get an index on each half's start position plus chrom
            self.dbcon.execute("CREATE INDEX sPosIndex on splices(chrom, startL)")
            print "built sPosIndex"
            self.dbcon.execute("CREATE INDEX sPosIndex2 on splices(chrom, startR)")
            print "built sPosIndex2"
            self.dbcon.execute("CREATE INDEX sChromIndex on splices(chrom)")
            print "built sChromIndex"

        self.dbcon.commit()
        self.setSynchronousPragma("ON")
1975
1976
    def dropIndex(self):
        """ drops the file indices for the main tables; best-effort, prints a
        warning instead of raising if any drop fails.
        """
        try:
            self.setSynchronousPragma("OFF")
            self.dbcon.execute("DROP INDEX uPosIndex")
            self.dbcon.execute("DROP INDEX uChromIndex")
            self.dbcon.execute("DROP INDEX mPosIndex")
            self.dbcon.execute("DROP INDEX mChromIndex")

            if self.dataType == "RNA":
                self.dbcon.execute("DROP INDEX sPosIndex")
                # sPosIndex2 may not exist on files indexed by older versions;
                # ignore that case
                try:
                    self.dbcon.execute("DROP INDEX sPosIndex2")
                except:
                    pass

                self.dbcon.execute("DROP INDEX sChromIndex")

            self.dbcon.commit()
        except:
            # NOTE(review): bare except hides which drop failed; consider
            # narrowing to sqlite.OperationalError and reporting the index
            print "problem dropping index"

        self.setSynchronousPragma("ON")
2001
2002
2003     def memSync(self, chrom="", index=False):
2004         """ makes a copy of the dataset into memory for faster access.
2005         Can be restricted to a "full" chromosome. Can also build the
2006         memory indices.
2007         """
2008         self.memcon = ""
2009         self.memcon = sqlite.connect(":memory:")
2010         self.initializeTables(self.memcon)
2011         cursor = self.dbcon.cursor()
2012         whereclause = ""
2013         if chrom != "":
2014             print "memSync %s" % chrom
2015             whereclause = " where chrom = '%s' " % chrom
2016             self.memChrom = chrom
2017         else:
2018             self.memChrom = ""
2019
2020         self.memcon.execute("PRAGMA temp_store = MEMORY")
2021         self.memcon.execute("PRAGMA CACHE_SIZE = 1000000")
2022         # copy metadata to memory
2023         self.memcon.execute("delete from metadata")
2024         results = cursor.execute("select name, value from metadata")
2025         results2 = []
2026         for row in results:
2027             results2.append((row["name"], row["value"]))
2028
2029         self.memcon.executemany("insert into metadata(name, value) values (?,?)", results2)
2030         # copy uniqs to memory
2031         results = cursor.execute("select chrom, start, stop, sense, weight, flag, mismatch, readID from uniqs" + whereclause)
2032         results2 = []
2033         for row in results:
2034             results2.append((row["readID"], row["chrom"], int(row["start"]), int(row["stop"]), row["sense"], row["weight"], row["flag"], row["mismatch"]))
2035
2036         self.memcon.executemany("insert into uniqs(ID, readID, chrom, start, stop, sense, weight, flag, mismatch) values (NULL,?,?,?,?,?,?,?,?)", results2)
2037         # copy multi to memory
2038         results = cursor.execute("select chrom, start, stop, sense, weight, flag, mismatch, readID from multi" + whereclause)
2039         results2 = []
2040         for row in results:
2041             results2.append((row["readID"], row["chrom"], int(row["start"]), int(row["stop"]), row["sense"], row["weight"], row["flag"], row["mismatch"]))
2042
2043         self.memcon.executemany("insert into multi(ID, readID, chrom, start, stop, sense, weight, flag, mismatch) values (NULL,?,?,?,?,?,?,?,?)", results2)
2044         # copy splices to memory
2045         if self.dataType == "RNA":
2046             results = cursor.execute("select chrom, startL, stopL, startR, stopR, sense, weight, flag, mismatch, readID from splices" + whereclause)
2047             results2 = []
2048             for row in results:
2049                 results2.append((row["readID"], row["chrom"], int(row["startL"]), int(row["stopL"]), int(row["startR"]), int(row["stopR"]), row["sense"], row["weight"], row["flag"], row["mismatch"]))
2050
2051             self.memcon.executemany("insert into splices(ID, readID, chrom, startL, stopL, startR, stopR, weight, sense, flag, mismatch) values (NULL,?,?,?,?,?,?,?,?,?,?)", results2)
2052         if index:
2053             if chrom != "":
2054                 self.memcon.execute("CREATE INDEX uPosIndex on uniqs(start)")
2055                 self.memcon.execute("CREATE INDEX mPosIndex on multi(start)")
2056                 if self.dataType == "RNA":
2057                     self.memcon.execute("CREATE INDEX sPosLIndex on splices(startL)")
2058                     self.memcon.execute("CREATE INDEX sPosRIndex on splices(startR)")
2059             else:
2060                 self.memcon.execute("CREATE INDEX uPosIndex on uniqs(chrom, start)")
2061                 self.memcon.execute("CREATE INDEX mPosIndex on multi(chrom, start)")
2062                 if self.dataType == "RNA":
2063                     self.memcon.execute("CREATE INDEX sPosLIndex on splices(chrom, startL)")
2064                     self.memcon.execute("CREATE INDEX sPosRIndex on splices(chrom, startR)")
2065
2066         self.memBacked = True
2067         self.memcon.row_factory = sqlite.Row
2068         self.memcon.commit()