rewrite of findall.py and MakeRdsFromBam to fix bugs resulting from poor initial...
[erange.git] / ReadDataset.py
index 56e28a8995b734735ace06d1d990ac9b46753580..71544ca2420d2489fca13452976d930a6c55a9c5 100644 (file)
@@ -6,7 +6,7 @@ import os
 from array import array
 from commoncode import getReverseComplement, getConfigParser, getConfigOption
 
-currentRDSVersion = "2.0"
+currentRDSVersion = "2.1"
 
 
 class ReadDatasetError(Exception):
@@ -35,6 +35,9 @@ class ReadDataset():
         self.memCursor = ""
         self.cachedDBFile = ""
 
+        if initialize and datasetType not in ["DNA", "RNA"]:
+            raise ReadDatasetError("failed to initialize: datasetType must be 'DNA' or 'RNA'")
+
         if cache:
             if verbose:
                 print "caching ...."
@@ -48,11 +51,7 @@ class ReadDataset():
         self.dbcon.row_factory = sqlite.Row
         self.dbcon.execute("PRAGMA temp_store = MEMORY")
         if initialize:
-            if datasetType not in ["DNA", "RNA"]:
-                raise ReadDatasetError("failed to initialize: datasetType must be 'DNA' or 'RNA'")
-            else:
-                self.dataType = datasetType
-
+            self.dataType = datasetType
             self.initializeTables(self.dbcon)
         else:
             metadata = self.getMetadata("dataType")
@@ -69,38 +68,7 @@ class ReadDataset():
                 self.rdsVersion = "pre-1.0"
 
         if verbose:
-            if initialize:
-                print "INITIALIZED dataset %s" % datafile
-            else:
-                print "dataset %s" % datafile
-
-            metadata = self.getMetadata()
-            print "metadata:"
-            pnameList = metadata.keys()
-            pnameList.sort()
-            for pname in pnameList:
-                print "\t" + pname + "\t" + metadata[pname]
-
-            if reportCount:
-                ucount = self.getUniqsCount()
-                mcount = self.getMultiCount()
-                if self.dataType == "DNA" and not initialize:
-                    try:
-                        print "\n%d unique reads and %d multireads" % (int(ucount), int(mcount))
-                    except ValueError:
-                        print "\n%s unique reads and %s multireads" % (ucount, mcount)
-                elif self.dataType == "RNA" and not initialize:
-                    scount = self.getSplicesCount()
-                    try:
-                        print "\n%d unique reads, %d spliced reads and %d multireads" % (int(ucount), int(scount), int(mcount))
-                    except ValueError:
-                        print "\n%s unique reads, %s spliced reads and %s multireads" % (ucount, scount, mcount)
-
-            print "default cache size is %d pages" % self.getDefaultCacheSize()
-            if self.hasIndex():
-                print "found index"
-            else:
-                print "not indexed"
+            self.printRDSInfo(datafile, reportCount, initialize)
 
 
     def __len__(self):
@@ -124,6 +92,39 @@ class ReadDataset():
             self.uncacheDB()
 
 
+    def printRDSInfo(self, datafile, reportCount, initialize):
+        if initialize:
+            print "INITIALIZED dataset %s" % datafile
+        else:
+            print "dataset %s" % datafile
+
+        metadata = self.getMetadata()
+        print "metadata:"
+        pnameList = metadata.keys()
+        pnameList.sort()
+        for pname in pnameList:
+            print "\t" + pname + "\t" + metadata[pname]
+
+        if reportCount and not initialize:
+            self.printReadCounts()
+
+        print "default cache size is %d pages" % self.getDefaultCacheSize()
+        if self.hasIndex():
+            print "found index"
+        else:
+            print "not indexed"
+
+
+    def printReadCounts(self):
+        ucount = self.getUniqsCount()
+        mcount = self.getMultiCount()
+        if self.dataType == "DNA":
+            print "\n%d unique reads and %d multireads" % (ucount, mcount)
+        elif self.dataType == "RNA":
+            scount = self.getSplicesCount()
+            print "\n%d unique reads, %d spliced reads and %d multireads" % (ucount, scount, mcount)
+
+
     def cacheDB(self, filename):
         """ copy geneinfoDB to a local cache.
         """
@@ -207,15 +208,27 @@ class ReadDataset():
         return sql
 
 
+    def getMemCursor(self):
+        """ returns a cursor to memory database for low-level (SQL)
+        access to the data.
+        """
+        return self.memcon.cursor()
+
+
+    def getFileCursor(self):
+        """ returns a cursor to file database for low-level (SQL)
+        access to the data.
+        """
+        return self.dbcon.cursor()
+
+
     def hasIndex(self):
-        """ check whether the RDS file has at least one index.
+        """ return True if the RDS file has at least one index.
         """
         stmt = "select count(*) from sqlite_master where type='index'"
         count = int(self.execute(stmt, returnResults=True)[0][0])
-        if count > 0:
-            return True
 
-        return False
+        return count > 0
 
 
     def initializeTables(self, dbConnection, cache=100000):
@@ -234,21 +247,11 @@ class ReadDataset():
             tableSchema = "(ID INTEGER PRIMARY KEY, readID varchar, chrom varchar, %s, sense varchar, weight real, flag varchar, mismatch varchar)" % positionSchema
             dbConnection.execute("create table splices %s" % tableSchema)
 
-        dbConnection.commit()
-
-
-    def getFileCursor(self):
-        """ returns a cursor to file database for low-level (SQL)
-        access to the data.
-        """
-        return self.dbcon.cursor()
-
+            positionSchema = "startL int, stopL int, startR int, stopR int"
+            tableSchema = "(ID INTEGER PRIMARY KEY, readID varchar, chrom varchar, %s, sense varchar, weight real, flag varchar, mismatch varchar)" % positionSchema
+            dbConnection.execute("create table multisplices %s" % tableSchema)
 
-    def getMemCursor(self):
-        """ returns a cursor to memory database for low-level (SQL)
-        access to the data.
-        """
-        return self.memcon.cursor()
+        dbConnection.commit()
 
 
     def getMetadata(self, valueName=""):
@@ -309,7 +312,7 @@ class ReadDataset():
 
 
     def getChromosomes(self, table="uniqs", fullChrom=True):
-        """ returns a list of distinct chromosomes in table.
+        """ returns a sorted list of distinct chromosomes in table.
         """
         statement = "select distinct chrom from %s" % table
         sql = self.getSqlCursor()
@@ -330,7 +333,7 @@ class ReadDataset():
         return results
 
 
-    def getMaxCoordinate(self, chrom, verbose=False, doUniqs=True,
+    def getMaxCoordinate(self, chrom, doUniqs=True,
                          doMulti=False, doSplices=False):
         """ returns the maximum coordinate for reads on a given chromosome.
         """
@@ -347,9 +350,6 @@ class ReadDataset():
             multiMax = self.getMaxStartCoordinateInTable(chrom, "multi")
             maxCoord = max(multiMax, maxCoord)
 
-        if verbose:
-            print "%s maxCoord: %d" % (chrom, maxCoord)
-
         return maxCoord
 
 
@@ -375,9 +375,9 @@ class ReadDataset():
         and which can be restricted by chromosome or custom-flag.
         Returns unique reads by default, but can return multireads
         with doMulti set to True.
-        
-        Need to rethink original design 1: Cannot have pairID without exporting as a readIDDict
+
         """
+        #TODO: Need to rethink original design 1: Cannot have pairID without exporting as a readIDDict
 
         whereQuery = self.getReadWhereQuery(chrom, flag, flagLike, start, stop, hasMismatch, strand, readLike)
         if findallOptimize:
@@ -421,27 +421,16 @@ class ReadDataset():
         if findallOptimize:
             if self.memBacked:
                 self.memcon.row_factory = None
-                sql = self.memcon.cursor()
             else:
                 self.dbcon.row_factory = None
-                sql = self.dbcon.cursor()
 
             stmt.append("order by start")
         elif readIDDict:
-            if self.memBacked:
-                sql = self.memcon.cursor()
-            else:
-                sql = self.dbcon.cursor()
-
             stmt.append("order by readID, start")
         else:
-            if self.memBacked:
-                sql = self.memcon.cursor()
-            else:
-                sql = self.dbcon.cursor()
-
             stmt.append("order by chrom, start")
 
+        sql = self.getSqlCursor()
         sqlQuery = string.join(stmt)
         sql.execute(sqlQuery)
 
@@ -602,10 +591,7 @@ class ReadDataset():
         whereQuery = self.getReadWhereQuery(chrom, flag, flagLike, start, stop, hasMismatch, strand, splice=True)
         selectClause = "select ID, chrom, startL, stopL, startR, stopR, readID"
         selectQuery = self.getReadSelectQuery(selectClause, noSense, withWeight, withFlag, withMismatch)
-        if self.memBacked:
-            sql = self.memcon.cursor()
-        else:
-            sql = self.dbcon.cursor()
+        sql = self.getSqlCursor()
 
         stmt = "%s from splices %s order by chrom, startL" % (selectQuery, whereQuery)
         sql.execute(stmt)
@@ -718,7 +704,7 @@ class ReadDataset():
 
 
     def getTableEntryCount(self, table, chrom="", rmin="", rmax="", restrict="", distinct=False, startField="start"):
-        """ returns the number of row in the uniqs table.
+        """ returns the number of row in the specified table.
         """
         whereClause = []
         count = 0
@@ -741,10 +727,7 @@ class ReadDataset():
         else:
             whereQuery = ""
 
-        if self.memBacked:
-            sql = self.memcon.cursor()
-        else:
-            sql = self.dbcon.cursor()
+        sql = self.getSqlCursor()
 
         if distinct:
             sql.execute("select count(distinct chrom+%s+sense) from %s %s" % (startField, table, whereQuery))
@@ -764,6 +747,7 @@ class ReadDataset():
     def getSplicesCount(self, chrom="", rmin="", rmax="", restrict="", distinct=False):
         """ returns the number of row in the splices table.
         """
+        # TODO: if the rds type is DNA should this just return zero?
         return self.getTableEntryCount("splices", chrom, rmin, rmax, restrict, distinct, startField="startL")
 
 
@@ -802,11 +786,7 @@ class ReadDataset():
             limitPart = "LIMIT %d" % limit
 
         sqlQuery = "%s group by readID %s" % (selectPart, limitPart)
-        if self.memBacked:
-            sql = self.memcon.cursor()
-        else:
-            sql = self.dbcon.cursor()
-
+        sql = self.getSqlCursor()
         sql.execute(sqlQuery)
         result = sql.fetchall()
 
@@ -1018,6 +998,14 @@ class ReadDataset():
         self.dbcon.commit()
 
 
+    def insertMultisplices(self, valuesList):
+        """ inserts a list of (readID, chrom, startL, stopL, startR, stopR, sense, weight, flag, mismatch)
+        into the multisplices table.
+        """
+        self.dbcon.executemany("insert into multisplices(ID, readID, chrom, startL, stopL, startR, stopR, sense, weight, flag, mismatch) values (NULL,?,?,?,?,?,?,?,?,?,?)", valuesList)
+        self.dbcon.commit()
+
+
     def flagReads(self, regionsList, uniqs=True, multi=False, splices=False, sense="both"):
         """ update reads on file database in a list region of regions for a chromosome to have a new flag.
             regionsList must have 4 fields per region of the form (flag, chrom, start, stop) or, with
@@ -1221,5 +1209,5 @@ class ReadDataset():
             destinationEntries.append((row["readID"], row["chrom"], int(row["startL"]), int(row["stopL"]), int(row["startR"]), int(row["stopR"]), row["sense"],
                                        row["weight"], row["flag"], row["mismatch"]))
 
-        self.memcon.executemany("insert into splices(ID, readID, chrom, startL, stopL, startR, stopR, sense, weight, flag, mismatch) values (NULL,?,?,?,?,?,?,?,?)", destinationEntries)
+        self.memcon.executemany("insert into splices(ID, readID, chrom, startL, stopL, startR, stopR, sense, weight, flag, mismatch) values (NULL,?,?,?,?,?,?,?,?,?,?)", destinationEntries)