rewrite of findall.py and MakeRdsFromBam to fix bugs resulting from poor initial...
[erange.git] / ReadDataset.py
1 import sqlite3 as sqlite
2 import string
3 import tempfile
4 import shutil
5 import os
6 from array import array
7 from commoncode import getReverseComplement, getConfigParser, getConfigOption
8
# RDS version string used as the default for new ReadDataset objects and
# written to the metadata of files that do not yet record a version.
currentRDSVersion = "2.1"
10
11
class ReadDatasetError(Exception):
    """ Error raised by ReadDataset, e.g. for an unknown dataset type or a
    missing or negative readsize.
    """
14
15
16 class ReadDataset():
17     """ Class for storing reads from experiments. Assumes that custom scripts
18     will translate incoming data into a format that can be inserted into the
19     class using the insert* methods. Default class subtype ('DNA') includes
20     tables for unique and multireads, whereas 'RNA' subtype also includes a
21     splices table.
22     """
23
24     def __init__(self, datafile, initialize=False, datasetType="DNA", verbose=False, 
25                  cache=False, reportCount=True):
26         """ creates an rds datafile if initialize is set to true, otherwise
27         will append to existing tables. datasetType can be either 'DNA' or 'RNA'.
28         """
29         self.dbcon = ""
30         self.memcon = ""
31         self.dataType = ""
32         self.rdsVersion = currentRDSVersion
33         self.memBacked = False
34         self.memChrom = ""
35         self.memCursor = ""
36         self.cachedDBFile = ""
37
38         if initialize and datasetType not in ["DNA", "RNA"]:
39             raise ReadDatasetError("failed to initialize: datasetType must be 'DNA' or 'RNA'")
40
41         if cache:
42             if verbose:
43                 print "caching ...."
44
45             self.cacheDB(datafile)
46             dbFile = self.cachedDBFile
47         else:
48             dbFile = datafile
49
50         self.dbcon = sqlite.connect(dbFile)
51         self.dbcon.row_factory = sqlite.Row
52         self.dbcon.execute("PRAGMA temp_store = MEMORY")
53         if initialize:
54             self.dataType = datasetType
55             self.initializeTables(self.dbcon)
56         else:
57             metadata = self.getMetadata("dataType")
58             self.dataType = metadata["dataType"]
59
60         try:
61             metadata = self.getMetadata("rdsVersion")
62             self.rdsVersion = metadata["rdsVersion"]
63         except:
64             try:
65                 self.insertMetadata([("rdsVersion", float(currentRDSVersion))])
66             except IOError:
67                 print "could not add rdsVersion - read-only ?"
68                 self.rdsVersion = "pre-1.0"
69
70         if verbose:
71             self.printRDSInfo(datafile, reportCount, initialize)
72
73
74     def __len__(self):
75         """ return the number of usable reads in the dataset.
76         """
77         total = self.getUniqsCount()
78         total += self.getMultiCount()
79
80         if self.dataType == "RNA":
81             total += self.getSplicesCount()
82
83         total = int(total)
84
85         return total
86
87
88     def __del__(self):
89         """ cleanup copy in local cache, if present.
90         """
91         if self.cachedDBFile != "":
92             self.uncacheDB()
93
94
95     def printRDSInfo(self, datafile, reportCount, initialize):
96         if initialize:
97             print "INITIALIZED dataset %s" % datafile
98         else:
99             print "dataset %s" % datafile
100
101         metadata = self.getMetadata()
102         print "metadata:"
103         pnameList = metadata.keys()
104         pnameList.sort()
105         for pname in pnameList:
106             print "\t" + pname + "\t" + metadata[pname]
107
108         if reportCount and not initialize:
109             self.printReadCounts()
110
111         print "default cache size is %d pages" % self.getDefaultCacheSize()
112         if self.hasIndex():
113             print "found index"
114         else:
115             print "not indexed"
116
117
118     def printReadCounts(self):
119         ucount = self.getUniqsCount()
120         mcount = self.getMultiCount()
121         if self.dataType == "DNA":
122             print "\n%d unique reads and %d multireads" % (ucount, mcount)
123         elif self.dataType == "RNA":
124             scount = self.getSplicesCount()
125             print "\n%d unique reads, %d spliced reads and %d multireads" % (ucount, scount, mcount)
126
127
128     def cacheDB(self, filename):
129         """ copy geneinfoDB to a local cache.
130         """
131         configParser = getConfigParser()
132         cisTemp = getConfigOption(configParser, "general", "cistematic_temp", default="/tmp")
133         tempfile.tempdir = cisTemp
134         self.cachedDBFile =  "%s.db" % tempfile.mktemp()
135         shutil.copyfile(filename, self.cachedDBFile)
136
137
138     def saveCacheDB(self, filename):
139         """ copy geneinfoDB to a local cache.
140         """
141         shutil.copyfile(self.cachedDBFile, filename)
142
143
144     def uncacheDB(self):
145         """ delete geneinfoDB from local cache.
146         """
147         global cachedDBFile
148         if self.cachedDBFile != "":
149             try:
150                 os.remove(self.cachedDBFile)
151             except:
152                 print "could not delete %s" % self.cachedDBFile
153
154             self.cachedDB = ""
155
156
157     def attachDB(self, filename, dbName):
158         """ attach another database file to the readDataset.
159         """
160         stmt = "attach '%s' as %s" % (filename, dbName)
161         self.execute(stmt)
162
163
164     def detachDB(self, dbName):
165         """ detach a database file to the readDataset.
166         """
167         stmt = "detach %s" % (dbName)
168         self.execute(stmt)
169
170
171     def importFromDB(self, dbName, table, ascolumns="*", destcolumns="", flagged=""):
172         """ import into current RDS the table (with columns destcolumns,
173             with default all columns) from the database file asname,
174             using the column specification of ascolumns (default all).
175         """
176         stmt = "insert into %s %s select %s from %s.%s" % (table, destcolumns, ascolumns, dbName, table)
177         if flagged != "":
178             stmt += " where flag = '%s' " % flagged
179
180         self.executeCommit(stmt)
181
182
183     def getTables(self, dbName=""):
184         """ get a list of table names in a particular database file.
185         """
186         resultList = []
187         sql = self.getSqlCursor()
188
189         if dbName != "":
190             dbName = "%s." % dbName
191
192         stmt = "select name from %ssqlite_master where type='table'" % dbName
193         sql.execute(stmt)
194         results = sql.fetchall()
195
196         for row in results:
197             resultList.append(row["name"])
198
199         return resultList
200
201
202     def getSqlCursor(self):
203         if self.memBacked:
204             sql = self.getMemCursor()
205         else:
206             sql = self.getFileCursor()
207
208         return sql
209
210
211     def getMemCursor(self):
212         """ returns a cursor to memory database for low-level (SQL)
213         access to the data.
214         """
215         return self.memcon.cursor()
216
217
218     def getFileCursor(self):
219         """ returns a cursor to file database for low-level (SQL)
220         access to the data.
221         """
222         return self.dbcon.cursor()
223
224
225     def hasIndex(self):
226         """ return True if the RDS file has at least one index.
227         """
228         stmt = "select count(*) from sqlite_master where type='index'"
229         count = int(self.execute(stmt, returnResults=True)[0][0])
230
231         return count > 0
232
233
234     def initializeTables(self, dbConnection, cache=100000):
235         """ creates table schema in a database connection, which is
236         typically a database file or an in-memory database.
237         """
238         dbConnection.execute("PRAGMA DEFAULT_CACHE_SIZE = %d" % cache)
239         dbConnection.execute("create table metadata (name varchar, value varchar)")
240         dbConnection.execute("insert into metadata values('dataType','%s')" % self.dataType)
241         positionSchema = "start int, stop int"
242         tableSchema = "(ID INTEGER PRIMARY KEY, readID varchar, chrom varchar, %s, sense varchar, weight real, flag varchar, mismatch varchar)" % positionSchema
243         dbConnection.execute("create table uniqs %s" % tableSchema)
244         dbConnection.execute("create table multi %s" % tableSchema)
245         if self.dataType == "RNA":
246             positionSchema = "startL int, stopL int, startR int, stopR int"
247             tableSchema = "(ID INTEGER PRIMARY KEY, readID varchar, chrom varchar, %s, sense varchar, weight real, flag varchar, mismatch varchar)" % positionSchema
248             dbConnection.execute("create table splices %s" % tableSchema)
249
250             positionSchema = "startL int, stopL int, startR int, stopR int"
251             tableSchema = "(ID INTEGER PRIMARY KEY, readID varchar, chrom varchar, %s, sense varchar, weight real, flag varchar, mismatch varchar)" % positionSchema
252             dbConnection.execute("create table multisplices %s" % tableSchema)
253
254         dbConnection.commit()
255
256
257     def getMetadata(self, valueName=""):
258         """ returns a dictionary of metadata.
259         """
260         whereClause = ""
261         resultsDict = {}
262
263         if valueName != "":
264             whereClause = " where name='%s'" % valueName
265
266         sql = self.getSqlCursor()
267
268         sql.execute("select name, value from metadata %s" % whereClause)
269         results = sql.fetchall()
270
271         for row in results:
272             parameterName = row["name"]
273             parameterValue = row["value"]
274             if parameterName not in resultsDict:
275                 resultsDict[parameterName] = parameterValue
276             else:
277                 trying = True
278                 index = 2
279                 while trying:
280                     newName = string.join([parameterName, str(index)], ":")
281                     if newName not in resultsDict:
282                         resultsDict[newName] = parameterValue
283                         trying = False
284
285                     index += 1
286
287         return resultsDict
288
289
290     def getReadSize(self):
291         """ returns readsize if defined in metadata.
292         """
293         metadata = self.getMetadata()
294         if "readsize" not in metadata:
295             raise ReadDatasetError("no readsize parameter defined")
296         else:
297             readSize = metadata["readsize"]
298             if "import" in readSize:
299                 readSize = readSize.split()[0]
300
301             readSize = int(readSize)
302             if readSize < 0:
303                 raise ReadDatasetError("readsize is negative")
304
305             return readSize
306
307
308     def getDefaultCacheSize(self):
309         """ returns the default cache size.
310         """
311         return int(self.execute("PRAGMA DEFAULT_CACHE_SIZE", returnResults=True)[0][0])
312
313
314     def getChromosomes(self, table="uniqs", fullChrom=True):
315         """ returns a sorted list of distinct chromosomes in table.
316         """
317         statement = "select distinct chrom from %s" % table
318         sql = self.getSqlCursor()
319
320         sql.execute(statement)
321         results = []
322         for row in sql:
323             if fullChrom:
324                 if row["chrom"] not in results:
325                     results.append(row["chrom"])
326             else:
327                 shortName = row["chrom"][3:]
328                 if  len(shortName.strip()) > 0 and shortName not in results:
329                     results.append(shortName)
330
331         results.sort()
332
333         return results
334
335
336     def getMaxCoordinate(self, chrom, doUniqs=True,
337                          doMulti=False, doSplices=False):
338         """ returns the maximum coordinate for reads on a given chromosome.
339         """
340         maxCoord = 0
341
342         if doUniqs:
343             maxCoord = self.getMaxStartCoordinateInTable(chrom, "uniqs")
344
345         if doSplices:
346             spliceMax = self.getMaxStartCoordinateInTable(chrom, "splices", startField="startR")
347             maxCoord = max(spliceMax, maxCoord)
348
349         if doMulti:
350             multiMax = self.getMaxStartCoordinateInTable(chrom, "multi")
351             maxCoord = max(multiMax, maxCoord)
352
353         return maxCoord
354
355
356     def getMaxStartCoordinateInTable(self, chrom, table, startField="start"):
357         maxCoord = 0
358         sqlStatement = "select max(%s) from %s where chrom = '%s'" % (startField, table, chrom)
359         sql = self.getSqlCursor()
360         try:
361             sql.execute(sqlStatement)
362             maxCoord = int(sql.fetchall()[0][0])
363         except:
364             print "couldn't retrieve coordMax for chromosome %s" % chrom
365
366         return maxCoord
367
368
    def getReadsDict(self, bothEnds=False, noSense=False, fullChrom=False, chrom="",
                     flag="", withWeight=False, withFlag=False, withMismatch=False, withID=False,
                     withChrom=False, withPairID=False, doUniqs=True, doMulti=False, findallOptimize=False,
                     readIDDict=False, readLike="", start=-1, stop=-1, limit=-1, hasMismatch=False,
                     flagLike=False, strand='', combine5p=False):
        """ returns a dictionary of reads in a variety of formats
        and which can be restricted by chromosome or custom-flag.
        Returns unique reads by default, but can return multireads
        with doMulti set to True.

        The result maps a key to a list of per-read dictionaries. The key is
        the chromosome name (the full name if fullChrom is True, otherwise
        the name without its first three characters) or, if readIDDict is
        True, the read ID. Each read dictionary always has "start"; "stop",
        "sense", "weight", "flag", "mismatch", "readID", "chrom" and "pairID"
        are added according to bothEnds, noSense, withWeight, withFlag,
        withMismatch, withID, withChrom and withPairID.

        chrom, flag, flagLike, start, stop, hasMismatch, strand and readLike
        are passed to getReadWhereQuery to restrict the rows; limit is passed
        to getReadGroupQuery.

        With findallOptimize the result has the single key chrom, and each
        entry holds only "start", "sense" and the summed "weight".
        """
        #TODO: Need to rethink original design 1: Cannot have pairID without exporting as a readIDDict

        whereQuery = self.getReadWhereQuery(chrom, flag, flagLike, start, stop, hasMismatch, strand, readLike)
        if findallOptimize:
            selectQuery = "select start, sense, sum(weight)"
        else:
            selectQuery = self.getReadSelectQuery("select ID, chrom, start, readID", noSense, withWeight, withFlag, withMismatch, bothEnds)

        groupQuery = self.getReadGroupQuery(findallOptimize, limit, combine5p)
        if doUniqs:
            stmt = [selectQuery, "from uniqs", whereQuery, groupQuery]
            if doMulti:
                stmt.append("UNION ALL")
                stmt.append(selectQuery)
                stmt.append("from multi")
                stmt.append(whereQuery)
                stmt.append(groupQuery)
        else:
            # NOTE(review): unlike the uniqs branch, groupQuery (GROUP BY / LIMIT)
            # is not appended here - confirm whether that is intended
            stmt = [selectQuery, "from multi", whereQuery]

        if combine5p:
            # the statement built above is discarded and replaced by a query
            # over a subselect grouped by chrom and start
            if findallOptimize:
                selectQuery = "select start, sense, weight, chrom"

            if doUniqs:
                subSelect = [selectQuery, "from uniqs", whereQuery]
                if doMulti:
                    subSelect.append("union all")
                    subSelect.append(selectQuery)
                    subSelect.append("from multi")
                    subSelect.append(whereQuery)
            else:
                subSelect = [selectQuery, "from multi", whereQuery]

            sqlStmt = string.join(subSelect)
            if findallOptimize:
                selectQuery = "select start, sense, sum(weight)"

            stmt = [selectQuery, "from (", sqlStmt, ") group by chrom,start having ( count(start) > 1 and count(chrom) > 1) union",
                    selectQuery, "from(", sqlStmt, ") group by chrom, start having ( count(start) = 1 and count(chrom) = 1)"]

        if findallOptimize:
            # rows are read by position below, so the Row factory is switched
            # off for this query and restored after the results are read
            if self.memBacked:
                self.memcon.row_factory = None
            else:
                self.dbcon.row_factory = None

            stmt.append("order by start")
        elif readIDDict:
            stmt.append("order by readID, start")
        else:
            stmt.append("order by chrom, start")

        sql = self.getSqlCursor()
        sqlQuery = string.join(stmt)
        sql.execute(sqlQuery)

        resultsDict = {}
        if findallOptimize:
            resultsDict[chrom] = [{"start": int(row[0]), "sense": row[1], "weight": float(row[2])} for row in sql]
            if self.memBacked:
                self.memcon.row_factory = sqlite.Row
            else:
                self.dbcon.row_factory = sqlite.Row
        else:
            currentChrom = ""
            currentReadID = ""
            # NOTE(review): pairID is only set once here, so a read without a
            # "/" keeps the pairID of the previous read - confirm intent
            pairID = 0
            for row in sql:
                readID = row["readID"]
                # from here on chrom names the current row's chromosome, not
                # the chrom argument
                if fullChrom:
                    chrom = row["chrom"]
                else:
                    chrom = row["chrom"][3:]

                # a new list is started whenever the key differs from the
                # previous row's key, relying on the "order by" chosen above
                if not readIDDict and chrom != currentChrom:
                    resultsDict[chrom] = []
                    currentChrom = chrom
                    dictKey = chrom
                elif readIDDict:
                    theReadID = readID
                    if "::" in readID:
                        theReadID = readID.split("::")[0]

                    # NOTE(review): this splits the full readID rather than
                    # theReadID, so any "::" suffix ends up in pairID and an ID
                    # with more than one "/" raises ValueError - confirm
                    if "/" in theReadID and withPairID:
                        (theReadID, pairID) = readID.split("/")

                    if theReadID != currentReadID:
                        resultsDict[theReadID] = []
                        currentReadID = theReadID
                        dictKey = theReadID

                newrow = {"start": int(row["start"])}
                if bothEnds:
                    newrow["stop"] = int(row["stop"])

                if not noSense:
                    newrow["sense"] = row["sense"]

                if withWeight:
                    newrow["weight"] = float(row["weight"])

                if withFlag:
                    newrow["flag"] = row["flag"]

                if withMismatch:
                    newrow["mismatch"] = row["mismatch"]

                if withID:
                    newrow["readID"] = readID

                if withChrom:
                    newrow["chrom"] = chrom

                if withPairID:
                    newrow["pairID"] = pairID

                resultsDict[dictKey].append(newrow)

        return resultsDict
500
501
502     def getReadWhereQuery(self, chrom, flag, flagLike, start, stop, hasMismatch, strand, readLike="", splice=False):
503         if splice:
504             startText = "startL"
505             stopText = "stopR"
506         else:
507             startText = "start"
508             stopText = "stop"
509
510         whereClause = []
511         if chrom != "" and chrom != self.memChrom:
512             whereClause.append("chrom = '%s'" % chrom)
513
514         if flag != "":
515             if flagLike:
516                 flagLikeClause = string.join(['flag LIKE "%', flag, '%"'], "")
517                 whereClause.append(flagLikeClause)
518             else:
519                 whereClause.append("flag = '%s'" % flag)
520
521         if start > -1:
522             whereClause.append("%s > %d" % (startText, start))
523
524         if stop > -1:
525             whereClause.append("%s < %d" % (stopText, stop))
526
527         if len(readLike) > 0:
528             readIDClause = string.join(["readID LIKE  '", readLike, "%'"], "")
529             whereClause.append(readIDClause)
530
531         if hasMismatch:
532             whereClause.append("mismatch != ''")
533
534         if strand in ["+", "-"]:
535             whereClause.append("sense = '%s'" % strand)
536
537         if len(whereClause) > 0:
538             whereStatement = string.join(whereClause, " and ")
539             whereQuery = "where %s" % whereStatement
540         else:
541             whereQuery = ""
542
543         return whereQuery
544
545
546     def getReadSelectQuery(self, baseSelect, noSense, withWeight, withFlag, withMismatch, bothEnds=False):
547
548         selectClause = [baseSelect]
549         if bothEnds:
550             selectClause.append("stop")
551
552         if not noSense:
553             selectClause.append("sense")
554
555         if withWeight:
556             selectClause.append("weight")
557
558         if withFlag:
559             selectClause.append("flag")
560
561         if withMismatch:
562             selectClause.append("mismatch")
563
564         selectQuery = string.join(selectClause, ",")
565
566         return selectQuery
567
568
569     def getReadGroupQuery(self, findallOptimize, limit, combine5p):
570         groupBy = []
571         if findallOptimize:
572             groupBy = ["GROUP BY start, sense"]
573
574         if limit > 0 and not combine5p:
575             groupBy.append("LIMIT %d" % limit)
576
577         groupQuery = string.join(groupBy)
578
579         return groupQuery
580
581
582     def getSplicesDict(self, noSense=False, fullChrom=False, chrom="",
583                        flag="", withWeight=False, withFlag=False, withMismatch=False,
584                        withID=False, withChrom=False, withPairID=False, readIDDict=False,
585                        splitRead=False, hasMismatch=False, flagLike=False, start=-1,
586                        stop=-1, strand=""):
587         """ returns a dictionary of spliced reads in a variety of
588         formats and which can be restricted by chromosome or custom-flag.
589         Returns unique spliced reads for now.
590         """
591         whereQuery = self.getReadWhereQuery(chrom, flag, flagLike, start, stop, hasMismatch, strand, splice=True)
592         selectClause = "select ID, chrom, startL, stopL, startR, stopR, readID"
593         selectQuery = self.getReadSelectQuery(selectClause, noSense, withWeight, withFlag, withMismatch)
594         sql = self.getSqlCursor()
595
596         stmt = "%s from splices %s order by chrom, startL" % (selectQuery, whereQuery)
597         sql.execute(stmt)
598         currentReadID = ""
599         currentChrom = ""
600         resultsDict = {}
601         for row in sql:
602             pairID = 0
603             readID = row["readID"]
604             if fullChrom:
605                 chrom = row["chrom"]
606             else:
607                 chrom = row["chrom"][3:]
608
609             if not readIDDict and chrom != currentChrom:
610                 resultsDict[chrom] = []
611                 currentChrom = chrom
612                 dictKey = chrom
613             elif readIDDict:
614                 if "/" in readID:
615                     (theReadID, pairID) = readID.split("/")
616                 else:
617                     theReadID = readID
618
619                 if theReadID != currentReadID:
620                     resultsDict[theReadID] = []
621                     currentReadID = theReadID
622                     dictKey = theReadID
623
624             newrow = {"startL": int(row["startL"])}
625             newrow["stopL"] = int(row["stopL"])
626             newrow["startR"] = int(row["startR"])
627             newrow["stopR"] = int(row["stopR"])
628             if not noSense:
629                 newrow["sense"] = row["sense"]
630
631             if withWeight:
632                 newrow["weight"] = float(row["weight"])
633
634             if withFlag:
635                 newrow["flag"] = row["flag"]
636
637             if withMismatch:
638                 newrow["mismatch"] = row["mismatch"]
639
640             if withID:
641                 newrow["readID"] = readID
642
643             if withChrom:
644                 newrow["chrom"] = chrom
645
646             if withPairID:
647                 newrow["pairID"] = pairID
648
649             if splitRead:
650                 leftDict = newrow.copy()
651                 del leftDict["startR"]
652                 del leftDict["stopR"]
653                 rightDict = newrow
654                 del rightDict["startL"]
655                 del rightDict["stopL"]
656                 resultsDict[dictKey].append(leftDict)
657                 resultsDict[dictKey].append(rightDict)
658             else:
659                 resultsDict[dictKey].append(newrow)
660
661         return resultsDict
662
663
664     def getCounts(self, chrom="", rmin="", rmax="", uniqs=True, multi=False,
665                   splices=False, reportCombined=True, sense="both"):
666         """ return read counts for a given region.
667         """
668         ucount = 0
669         mcount = 0
670         scount = 0
671         restrict = ""
672         if sense in ["+", "-"]:
673             restrict = " sense ='%s' " % sense
674
675         if uniqs:
676             try:
677                 ucount = float(self.getUniqsCount(chrom, rmin, rmax, restrict))
678             except:
679                 ucount = 0
680
681         if multi:
682             try:
683                 mcount = float(self.getMultiCount(chrom, rmin, rmax, restrict))
684             except:
685                 mcount = 0
686
687         if splices:
688             try:
689                 scount = float(self.getSplicesCount(chrom, rmin, rmax, restrict))
690             except:
691                 scount = 0
692
693         if reportCombined:
694             total = ucount + mcount + scount
695             return total
696         else:
697             return (ucount, mcount, scount)
698
699
700     def getTotalCounts(self, chrom="", rmin="", rmax=""):
701         """ return read counts for a given region.
702         """
703         return self.getCounts(chrom, rmin, rmax, uniqs=True, multi=True, splices=True, reportCombined=True, sense="both")
704
705
706     def getTableEntryCount(self, table, chrom="", rmin="", rmax="", restrict="", distinct=False, startField="start"):
707         """ returns the number of row in the specified table.
708         """
709         whereClause = []
710         count = 0
711
712         if chrom !=""  and chrom != self.memChrom:
713             whereClause = ["chrom='%s'" % chrom]
714
715         if rmin != "":
716             whereClause.append("%s >= %s" % (startField, str(rmin)))
717
718         if rmax != "":
719             whereClause.append("%s <= %s" % (startField, str(rmax)))
720
721         if restrict != "":
722             whereClause.append(restrict)
723
724         if len(whereClause) > 0:
725             whereStatement = string.join(whereClause, " and ")
726             whereQuery = "where %s" % whereStatement
727         else:
728             whereQuery = ""
729
730         sql = self.getSqlCursor()
731
732         if distinct:
733             sql.execute("select count(distinct chrom+%s+sense) from %s %s" % (startField, table, whereQuery))
734         else:
735             sql.execute("select sum(weight) from %s %s" % (table, whereQuery))
736
737         result = sql.fetchone()
738
739         try:
740             count = int(result[0])
741         except:
742             count = 0
743
744         return count
745
746
747     def getSplicesCount(self, chrom="", rmin="", rmax="", restrict="", distinct=False):
748         """ returns the number of row in the splices table.
749         """
750         # TODO: if the rds type is DNA should this just return zero?
751         return self.getTableEntryCount("splices", chrom, rmin, rmax, restrict, distinct, startField="startL")
752
753
754     def getUniqsCount(self, chrom="", rmin="", rmax="", restrict="", distinct=False):
755         """ returns the number of distinct readIDs in the uniqs table.
756         """
757         return self.getTableEntryCount("uniqs", chrom, rmin, rmax, restrict, distinct)
758
759
760     def getMultiCount(self, chrom="", rmin="", rmax="", restrict="", distinct=False):
761         """ returns the total weight of readIDs in the multi table.
762         """
763         return self.getTableEntryCount("multi", chrom, rmin, rmax, restrict, distinct)
764
765
766     def getReadIDs(self, uniqs=True, multi=False, splices=False, paired=False, limit=-1):
767         """ get readID's.
768         """
769         stmt = []
770         if uniqs:
771             stmt.append("select readID from uniqs")
772
773         if multi:
774             stmt.append("select readID from multi")
775
776         if splices:
777             stmt.append("select readID from splices")
778
779         if len(stmt) > 0:
780             selectPart = string.join(stmt, " union ")
781         else:
782             selectPart = ""
783
784         limitPart = ""
785         if limit > 0:
786             limitPart = "LIMIT %d" % limit
787
788         sqlQuery = "%s group by readID %s" % (selectPart, limitPart)
789         sql = self.getSqlCursor()
790         sql.execute(sqlQuery)
791         result = sql.fetchall()
792
793         if paired:
794             return [x[0].split("/")[0] for x in result]
795         else:
796             return [x[0] for x in result]
797
798
799     def getMismatches(self, mischrom=None, verbose=False, useSplices=True):
800         """ returns the uniq and spliced mismatches in a dictionary.
801         """
802         readlen = self.getReadSize()
803         if mischrom:
804             hitChromList = [mischrom]
805         else:
806             hitChromList = self.getChromosomes()
807             hitChromList.sort()
808
809         snpDict = {}
810         for achrom in hitChromList:
811             if verbose:
812                 print "getting mismatches from chromosome %s" % (achrom)
813
814             snpDict[achrom] = []
815             if useSplices and self.dataType == "RNA":
816                 spliceDict = self.getSplicesDict(fullChrom=True, chrom=achrom, withMismatch=True, readIDDict=True, hasMismatch=True)
817                 spliceIDList = spliceDict.keys()
818                 for spliceID in spliceIDList:
819                     spliceEntry = spliceDict[spliceID][0]
820                     startpos = spliceEntry["startL"]
821                     lefthalf = spliceEntry["stopL"]
822                     rightstart = spliceEntry["startR"]
823                     sense = spliceEntry["sense"]
824                     mismatches = spliceEntry["mismatch"]
825                     spMismatchList = mismatches.split(",")
826                     for mismatch in spMismatchList:
827                         if "N" in mismatch:
828                             continue
829
830                         change_len = len(mismatch)
831                         if sense == "+":
832                             change_from = mismatch[0]
833                             change_base = mismatch[change_len-1]
834                             change_pos = int(mismatch[1:change_len-1])
835                         elif sense == "-":
836                             change_from = getReverseComplement([mismatch[0]])
837                             change_base = getReverseComplement([mismatch[change_len-1]])
838                             change_pos = readlen - int(mismatch[1:change_len-1]) + 1
839
840                         firsthalf = int(lefthalf)-int(startpos)+1
841                         secondhalf = 0
842                         if int(change_pos) <= int(firsthalf):
843                             change_at = startpos + change_pos - 1
844                         else:
845                             secondhalf = change_pos - firsthalf
846                             change_at = rightstart + secondhalf
847
848                         snpDict[achrom].append([startpos, change_at, change_base, change_from])
849
850             hitDict = self.getReadsDict(fullChrom=True, chrom=achrom, withMismatch=True, hasMismatch=True)
851             if achrom not in hitDict.keys():
852                 continue
853
854             for readEntry in hitDict[achrom]:
855                 start = readEntry["start"]
856                 sense = readEntry["sense"]
857                 mismatches = readEntry["mismatch"]
858                 mismatchList = mismatches.split(",")
859                 for mismatch in mismatchList:
860                     if "N" in mismatch:
861                         continue
862
863                     change_len = len(mismatch)
864                     if sense == "+":
865                         change_from = mismatch[0]
866                         change_base = mismatch[change_len-1]
867                         change_pos = int(mismatch[1:change_len-1])
868                     elif sense == "-":
869                         change_from = getReverseComplement([mismatch[0]])
870                         change_base = getReverseComplement([mismatch[change_len-1]])
871                         change_pos = readlen - int(mismatch[1:change_len-1]) + 1
872
873                     change_at = start + change_pos - 1
874                     snpDict[achrom].append([start, change_at, change_base, change_from])
875
876         return snpDict
877
878
879     def getChromProfile(self, chromosome, cstart=-1, cstop=-1, useMulti=True,
880                         useSplices=False, normalizationFactor=1.0, trackStrand=False,
881                         keepStrand="both", shiftValue=0):
882         """return a profile of the chromosome as an array of per-base read coverage....
883             keepStrand = 'both', 'plusOnly', or 'minusOnly'.
884             Will also shift position of unique and multireads (but not splices) if shift is a natural number
885         """
886         metadata = self.getMetadata()
887         try:
888             readlen = int(metadata["readsize"])
889         except KeyError:
890             readlen = 0
891
892         dataType = metadata["dataType"]
893         scale = 1. / normalizationFactor
894         shift = {}
895         shift["+"] = int(shiftValue)
896         shift["-"] = -1 * int(shiftValue)
897
898         if cstop > 0:
899             lastNT = self.getMaxCoordinate(chromosome, doMulti=useMulti, doSplices=useSplices) + readlen
900         else:
901             lastNT = cstop - cstart + readlen + shift["+"]
902
903         chromModel = array("f",[0.] * lastNT)
904         hitDict = self.getReadsDict(fullChrom=True, chrom=chromosome, withWeight=True, doMulti=useMulti, start=cstart, stop=cstop, findallOptimize=True)
905         if cstart < 0:
906             cstart = 0
907
908         for readEntry in hitDict[chromosome]:
909             hstart = readEntry["start"]
910             sense =  readEntry ["sense"]
911             weight = readEntry["weight"]
912             hstart = hstart - cstart + shift[sense]
913             for currentpos in range(hstart,hstart+readlen):
914                 try:
915                     if not trackStrand or (sense == "+" and keepStrand != "minusOnly"):
916                         chromModel[currentpos] += scale * weight
917                     elif sense == "-" and keepStrand != "plusOnly":
918                         chromModel[currentpos] -= scale * weight
919                 except:
920                     continue
921
922         del hitDict
923         if useSplices and dataType == "RNA":
924             if cstop > 0:
925                 spliceDict = self.getSplicesDict(fullChrom=True, chrom=chromosome, withID=True, start=cstart, stop=cstop)
926             else:
927                 spliceDict = self.getSplicesDict(fullChrom=True, chrom=chromosome, withID=True)
928    
929             if chromosome in spliceDict:
930                 for spliceEntry in spliceDict[chromosome]:
931                     Lstart = spliceEntry["startL"]
932                     Lstop = spliceEntry["stopL"]
933                     Rstart = spliceEntry["startR"]
934                     Rstop = spliceEntry["stopR"]
935                     rsense = spliceEntry["sense"]
936                     if (Rstop - cstart) < lastNT:
937                         for index in range(abs(Lstop - Lstart)):
938                             currentpos = Lstart - cstart + index
939                             # we only track unique splices
940                             if not trackStrand or (rsense == "+" and keepStrand != "minusOnly"):
941                                 chromModel[currentpos] += scale
942                             elif rsense == "-" and keepStrand != "plusOnly":
943                                 chromModel[currentpos] -= scale
944
945                         for index in range(abs(Rstop - Rstart)):
946                             currentpos = Rstart - cstart + index
947                             # we only track unique splices
948                             if not trackStrand or (rsense == "+" and keepStrand != "minusOnly"):
949                                 chromModel[currentpos] += scale
950                             elif rsense == "-" and keepStrand != "plusOnly":
951                                 chromModel[currentpos] -= scale
952
953             del spliceDict
954
955         return chromModel
956
957
958     def insertMetadata(self, valuesList):
959         """ inserts a list of (pname, pvalue) into the metadata
960         table.
961         """
962         self.dbcon.executemany("insert into metadata(name, value) values (?,?)", valuesList)
963         self.dbcon.commit()
964
965
966     def updateMetadata(self, pname, newValue, originalValue=""):
967         """ update a metadata field given the original value and the new value.
968         """
969         stmt = "update metadata set value='%s' where name='%s'" % (str(newValue), pname)
970         if originalValue != "":
971             stmt += " and value='%s' " % str(originalValue)
972
973         self.dbcon.execute(stmt)
974         self.dbcon.commit()
975
976
977     def insertUniqs(self, valuesList):
978         """ inserts a list of (readID, chrom, start, stop, sense, weight, flag, mismatch)
979         into the uniqs table.
980         """
981         self.dbcon.executemany("insert into uniqs(ID, readID, chrom, start, stop, sense, weight, flag, mismatch) values (NULL,?,?,?,?,?,?,?,?)", valuesList)
982         self.dbcon.commit()
983
984
985     def insertMulti(self, valuesList):
986         """ inserts a list of (readID, chrom, start, stop, sense, weight, flag, mismatch)
987         into the multi table.
988         """
989         self.dbcon.executemany("insert into multi(ID, readID, chrom, start, stop, sense, weight, flag, mismatch) values (NULL,?,?,?,?,?,?,?,?)", valuesList)
990         self.dbcon.commit()
991
992
993     def insertSplices(self, valuesList):
994         """ inserts a list of (readID, chrom, startL, stopL, startR, stopR, sense, weight, flag, mismatch)
995         into the splices table.
996         """
997         self.dbcon.executemany("insert into splices(ID, readID, chrom, startL, stopL, startR, stopR, sense, weight, flag, mismatch) values (NULL,?,?,?,?,?,?,?,?,?,?)", valuesList)
998         self.dbcon.commit()
999
1000
1001     def insertMultisplices(self, valuesList):
1002         """ inserts a list of (readID, chrom, startL, stopL, startR, stopR, sense, weight, flag, mismatch)
1003         into the multisplices table.
1004         """
1005         self.dbcon.executemany("insert into multisplices(ID, readID, chrom, startL, stopL, startR, stopR, sense, weight, flag, mismatch) values (NULL,?,?,?,?,?,?,?,?,?,?)", valuesList)
1006         self.dbcon.commit()
1007
1008
1009     def flagReads(self, regionsList, uniqs=True, multi=False, splices=False, sense="both"):
1010         """ update reads on file database in a list region of regions for a chromosome to have a new flag.
1011             regionsList must have 4 fields per region of the form (flag, chrom, start, stop) or, with
1012             sense set to '+' or '-', 5 fields per region of the form (flag, chrom, start, stop, sense).
1013         """
1014         restrict = ""
1015         if sense != "both":
1016             restrict = " and sense = ? "
1017
1018         if uniqs:
1019             self.dbcon.executemany("UPDATE uniqs SET flag = ? where chrom = ? and start >= ? and start < ? " + restrict, regionsList)
1020
1021         if multi:
1022             self.dbcon.executemany("UPDATE multi SET flag = ? where chrom = ? and start >= ? and start < ? " + restrict, regionsList)
1023
1024         if self.dataType == "RNA" and splices:
1025             self.dbcon.executemany("UPDATE splices SET flag = flag || ' L:' || ? where chrom = ? and startL >= ? and startL < ? " + restrict, regionsList)
1026             self.dbcon.executemany("UPDATE splices SET flag = flag || ' R:' || ? where chrom = ? and startR >= ? and startR < ? " + restrict, regionsList)
1027
1028         self.dbcon.commit()
1029
1030
1031     def setFlags(self, flag, uniqs=True, multi=True, splices=True):
1032         """ set the flag fields in the entire dataset.
1033         """
1034         if uniqs:
1035             self.dbcon.execute("UPDATE uniqs SET flag = '%s'" % flag)
1036
1037         if multi:
1038             self.dbcon.execute("UPDATE multi SET flag = '%s'" % flag)
1039
1040         if self.dataType == "RNA" and splices:
1041             self.dbcon.execute("UPDATE splices SET flag = '%s'" % flag)
1042
1043         self.dbcon.commit()
1044
1045
1046     def resetFlags(self, uniqs=True, multi=True, splices=True):
1047         """ reset the flag fields in the entire dataset to clear. Useful for rerunning an analysis from scratch.
1048         """
1049         self.setFlags("", uniqs, multi, splices)
1050
1051
1052     def reweighMultireads(self, readList):
1053         self.dbcon.executemany("UPDATE multi SET weight = ? where chrom = ? and start = ? and readID = ? ", readList)
1054
1055
1056     def setSynchronousPragma(self, value="ON"):
1057         try:
1058             self.dbcon.execute("PRAGMA SYNCHRONOUS = %s" % value)
1059         except:
1060             print "warning: couldn't set PRAGMA SYNCHRONOUS = %s" % value
1061
1062
1063     def setDBcache(self, cache, default=False):
1064         self.dbcon.execute("PRAGMA CACHE_SIZE = %d" % cache)
1065         if default:
1066             self.dbcon.execute("PRAGMA DEFAULT_CACHE_SIZE = %d" % cache)
1067
1068
1069     def execute(self, statement, returnResults=False):
1070         sql = self.getSqlCursor()
1071
1072         sql.execute(statement)
1073         if returnResults:
1074             result = sql.fetchall()
1075             return result
1076
1077
1078     def executeCommit(self, statement):
1079         self.execute(statement)
1080
1081         if self.memBacked:
1082             self.memcon.commit()
1083         else:
1084             self.dbcon.commit()
1085
1086
1087     def buildIndex(self, cache=100000):
1088         """ Builds the file indeces for the main tables.
1089             Cache is the number of 1.5 kb pages to keep in memory.
1090             100000 pages translates into 150MB of RAM, which is our default.
1091         """
1092         if cache > self.getDefaultCacheSize():
1093             self.setDBcache(cache)
1094         self.setSynchronousPragma("OFF")
1095         self.dbcon.execute("CREATE INDEX uPosIndex on uniqs(chrom, start)")
1096         print "built uPosIndex"
1097         self.dbcon.execute("CREATE INDEX uChromIndex on uniqs(chrom)")
1098         print "built uChromIndex"
1099         self.dbcon.execute("CREATE INDEX mPosIndex on multi(chrom, start)")
1100         print "built mPosIndex"
1101         self.dbcon.execute("CREATE INDEX mChromIndex on multi(chrom)")
1102         print "built mChromIndex"
1103
1104         if self.dataType == "RNA":
1105             self.dbcon.execute("CREATE INDEX sPosIndex on splices(chrom, startL)")
1106             print "built sPosIndex"
1107             self.dbcon.execute("CREATE INDEX sPosIndex2 on splices(chrom, startR)")
1108             print "built sPosIndex2"
1109             self.dbcon.execute("CREATE INDEX sChromIndex on splices(chrom)")
1110             print "built sChromIndex"
1111
1112         self.dbcon.commit()
1113         self.setSynchronousPragma("ON")
1114
1115
1116     def dropIndex(self):
1117         """ drops the file indices for the main tables.
1118         """
1119         try:
1120             self.setSynchronousPragma("OFF")
1121             self.dbcon.execute("DROP INDEX uPosIndex")
1122             self.dbcon.execute("DROP INDEX uChromIndex")
1123             self.dbcon.execute("DROP INDEX mPosIndex")
1124             self.dbcon.execute("DROP INDEX mChromIndex")
1125
1126             if self.dataType == "RNA":
1127                 self.dbcon.execute("DROP INDEX sPosIndex")
1128                 try:
1129                     self.dbcon.execute("DROP INDEX sPosIndex2")
1130                 except:
1131                     pass
1132
1133                 self.dbcon.execute("DROP INDEX sChromIndex")
1134
1135             self.dbcon.commit()
1136         except:
1137             print "problem dropping index"
1138
1139         self.setSynchronousPragma("ON")
1140
1141
1142     def memSync(self, chrom="", index=False):
1143         """ makes a copy of the dataset into memory for faster access.
1144         Can be restricted to a "full" chromosome. Can also build the
1145         memory indices.
1146         """
1147         self.memcon = ""
1148         self.memcon = sqlite.connect(":memory:")
1149         self.initializeTables(self.memcon)
1150         cursor = self.dbcon.cursor()
1151         whereclause = ""
1152         if chrom != "":
1153             print "memSync %s" % chrom
1154             whereclause = " where chrom = '%s' " % chrom
1155             self.memChrom = chrom
1156         else:
1157             self.memChrom = ""
1158
1159         self.memcon.execute("PRAGMA temp_store = MEMORY")
1160         self.memcon.execute("PRAGMA CACHE_SIZE = 1000000")
1161         # copy metadata to memory
1162         self.memcon.execute("delete from metadata")
1163         results = cursor.execute("select name, value from metadata")
1164         results2 = []
1165         for row in results:
1166             results2.append((row["name"], row["value"]))
1167
1168         self.memcon.executemany("insert into metadata(name, value) values (?,?)", results2)
1169
1170         self.copyDBEntriesToMemory("uniqs", whereclause)
1171         self.copyDBEntriesToMemory("multi", whereclause)
1172         if self.dataType == "RNA":
1173             self.copySpliceDBEntriesToMemory(whereclause)
1174
1175         if index:
1176             if chrom != "":
1177                 self.memcon.execute("CREATE INDEX uPosIndex on uniqs(start)")
1178                 self.memcon.execute("CREATE INDEX mPosIndex on multi(start)")
1179                 if self.dataType == "RNA":
1180                     self.memcon.execute("CREATE INDEX sPosLIndex on splices(startL)")
1181                     self.memcon.execute("CREATE INDEX sPosRIndex on splices(startR)")
1182             else:
1183                 self.memcon.execute("CREATE INDEX uPosIndex on uniqs(chrom, start)")
1184                 self.memcon.execute("CREATE INDEX mPosIndex on multi(chrom, start)")
1185                 if self.dataType == "RNA":
1186                     self.memcon.execute("CREATE INDEX sPosLIndex on splices(chrom, startL)")
1187                     self.memcon.execute("CREATE INDEX sPosRIndex on splices(chrom, startR)")
1188
1189         self.memBacked = True
1190         self.memcon.row_factory = sqlite.Row
1191         self.memcon.commit()
1192
1193
1194     def copyDBEntriesToMemory(self, dbName, whereClause=""):
1195         cursor = self.dbcon.cursor()
1196         sourceEntries = cursor.execute("select chrom, start, stop, sense, weight, flag, mismatch, readID from %s %s" % (dbName, whereClause))
1197         destinationEntries = []
1198         for row in sourceEntries:
1199             destinationEntries.append((row["readID"], row["chrom"], int(row["start"]), int(row["stop"]), row["sense"], row["weight"], row["flag"], row["mismatch"]))
1200
1201         self.memcon.executemany("insert into %s(ID, readID, chrom, start, stop, sense, weight, flag, mismatch) values (NULL,?,?,?,?,?,?,?,?)" % dbName, destinationEntries)
1202
1203
1204     def copySpliceDBEntriesToMemory(self, whereClause=""):
1205         cursor = self.dbcon.cursor()
1206         sourceEntries = cursor.execute("select chrom, startL, stopL, startR, stopR, sense, weight, flag, mismatch, readID from splices %s" % whereClause)
1207         destinationEntries = []
1208         for row in sourceEntries:
1209             destinationEntries.append((row["readID"], row["chrom"], int(row["startL"]), int(row["stopL"]), int(row["startR"]), int(row["stopR"]), row["sense"],
1210                                        row["weight"], row["flag"], row["mismatch"]))
1211
1212         self.memcon.executemany("insert into splices(ID, readID, chrom, startL, stopL, startR, stopR, sense, weight, flag, mismatch) values (NULL,?,?,?,?,?,?,?,?,?,?)", destinationEntries)
1213