Release version for Erange 4.0a
[erange.git] / ReadDataset.py
index 5ff60e2e6954c862888c5807495ef49142e14fa3..56e28a8995b734735ace06d1d990ac9b46753580 100644 (file)
@@ -153,42 +153,42 @@ class ReadDataset():
             self.cachedDB = ""
 
 
             self.cachedDB = ""
 
 
-    def attachDB(self, filename, asname):
+    def attachDB(self, filename, dbName):
         """ attach another database file to the readDataset.
         """
         """ attach another database file to the readDataset.
         """
-        stmt = "attach '%s' as %s" % (filename, asname)
+        stmt = "attach '%s' as %s" % (filename, dbName)
         self.execute(stmt)
 
 
         self.execute(stmt)
 
 
-    def detachDB(self, asname):
+    def detachDB(self, dbName):
         """ detach a database file to the readDataset.
         """
         """ detach a database file to the readDataset.
         """
-        stmt = "detach %s" % (asname)
+        stmt = "detach %s" % (dbName)
         self.execute(stmt)
 
 
         self.execute(stmt)
 
 
-    def importFromDB(self, asname, table, ascolumns="*", destcolumns="", flagged=""):
+    def importFromDB(self, dbName, table, ascolumns="*", destcolumns="", flagged=""):
         """ import into current RDS the table (with columns destcolumns,
             with default all columns) from the database file asname,
             using the column specification of ascolumns (default all).
         """
         """ import into current RDS the table (with columns destcolumns,
             with default all columns) from the database file asname,
             using the column specification of ascolumns (default all).
         """
-        stmt = "insert into %s %s select %s from %s.%s" % (table, destcolumns, ascolumns, asname, table)
+        stmt = "insert into %s %s select %s from %s.%s" % (table, destcolumns, ascolumns, dbName, table)
         if flagged != "":
             stmt += " where flag = '%s' " % flagged
 
         self.executeCommit(stmt)
 
 
         if flagged != "":
             stmt += " where flag = '%s' " % flagged
 
         self.executeCommit(stmt)
 
 
-    def getTables(self, asname=""):
+    def getTables(self, dbName=""):
         """ get a list of table names in a particular database file.
         """
         resultList = []
         sql = self.getSqlCursor()
 
         """ get a list of table names in a particular database file.
         """
         resultList = []
         sql = self.getSqlCursor()
 
-        if asname != "":
-            asname += "."
+        if dbName != "":
+            dbName = "%s." % dbName
 
 
-        stmt = "select name from %ssqlite_master where type='table'" % asname
+        stmt = "select name from %ssqlite_master where type='table'" % dbName
         sql.execute(stmt)
         results = sql.fetchall()
 
         sql.execute(stmt)
         results = sql.fetchall()
 
@@ -291,11 +291,15 @@ class ReadDataset():
         if "readsize" not in metadata:
             raise ReadDatasetError("no readsize parameter defined")
         else:
         if "readsize" not in metadata:
             raise ReadDatasetError("no readsize parameter defined")
         else:
-            mysize = metadata["readsize"]
-            if "import" in mysize:
-                mysize = mysize.split()[0]
+            readSize = metadata["readsize"]
+            if "import" in readSize:
+                readSize = readSize.split()[0]
 
 
-            return int(mysize)
+            readSize = int(readSize)
+            if readSize < 0:
+                raise ReadDatasetError("readsize is negative")
+
+            return readSize
 
 
     def getDefaultCacheSize(self):
 
 
     def getDefaultCacheSize(self):
@@ -317,11 +321,9 @@ class ReadDataset():
                 if row["chrom"] not in results:
                     results.append(row["chrom"])
             else:
                 if row["chrom"] not in results:
                     results.append(row["chrom"])
             else:
-                if  len(row["chrom"][3:].strip()) < 1:
-                    continue
-
-                if row["chrom"][3:] not in results:
-                    results.append(row["chrom"][3:])
+                shortName = row["chrom"][3:]
+                if  len(shortName.strip()) > 0 and shortName not in results:
+                    results.append(shortName)
 
         results.sort()
 
 
         results.sort()
 
@@ -333,32 +335,17 @@ class ReadDataset():
         """ returns the maximum coordinate for reads on a given chromosome.
         """
         maxCoord = 0
         """ returns the maximum coordinate for reads on a given chromosome.
         """
         maxCoord = 0
-        sql = self.getSqlCursor()
 
         if doUniqs:
 
         if doUniqs:
-            try:
-                sql.execute("select max(start) from uniqs where chrom = '%s'" % chrom)
-                maxCoord = int(sql.fetchall()[0][0])
-            except:
-                print "couldn't retrieve coordMax for chromosome %s" % chrom
+            maxCoord = self.getMaxStartCoordinateInTable(chrom, "uniqs")
 
         if doSplices:
 
         if doSplices:
-            sql.execute("select max(startR) from splices where chrom = '%s'" % chrom)
-            try:
-                spliceMax = int(sql.fetchall()[0][0])
-                if spliceMax > maxCoord:
-                    maxCoord = spliceMax
-            except:
-                pass
+            spliceMax = self.getMaxStartCoordinateInTable(chrom, "splices", startField="startR")
+            maxCoord = max(spliceMax, maxCoord)
 
         if doMulti:
 
         if doMulti:
-            sql.execute("select max(start) from multi where chrom = '%s'" % chrom)
-            try:
-                multiMax = int(sql.fetchall()[0][0])
-                if multiMax > maxCoord:
-                    maxCoord = multiMax
-            except:
-                pass
+            multiMax = self.getMaxStartCoordinateInTable(chrom, "multi")
+            maxCoord = max(multiMax, maxCoord)
 
         if verbose:
             print "%s maxCoord: %d" % (chrom, maxCoord)
 
         if verbose:
             print "%s maxCoord: %d" % (chrom, maxCoord)
@@ -366,6 +353,19 @@ class ReadDataset():
         return maxCoord
 
 
         return maxCoord
 
 
+    def getMaxStartCoordinateInTable(self, chrom, table, startField="start"):
+        maxCoord = 0
+        sqlStatement = "select max(%s) from %s where chrom = '%s'" % (startField, table, chrom)
+        sql = self.getSqlCursor()
+        try:
+            sql.execute(sqlStatement)
+            maxCoord = int(sql.fetchall()[0][0])
+        except:
+            print "couldn't retrieve coordMax for chromosome %s" % chrom
+
+        return maxCoord
+
+
     def getReadsDict(self, bothEnds=False, noSense=False, fullChrom=False, chrom="",
                      flag="", withWeight=False, withFlag=False, withMismatch=False, withID=False,
                      withChrom=False, withPairID=False, doUniqs=True, doMulti=False, findallOptimize=False,
     def getReadsDict(self, bothEnds=False, noSense=False, fullChrom=False, chrom="",
                      flag="", withWeight=False, withFlag=False, withMismatch=False, withID=False,
                      withChrom=False, withPairID=False, doUniqs=True, doMulti=False, findallOptimize=False,
@@ -378,67 +378,14 @@ class ReadDataset():
         
         Need to rethink original design 1: Cannot have pairID without exporting as a readIDDict
         """
         
         Need to rethink original design 1: Cannot have pairID without exporting as a readIDDict
         """
-        whereClause = []
-        resultsDict = {}
-
-        if chrom != "" and chrom != self.memChrom:
-            whereClause.append("chrom = '%s'" % chrom)
-
-        if flag != "":
-            if flagLike:
-                flagLikeClause = string.join(['flag LIKE "%', flag, '%"'], "")
-                whereClause.append(flagLikeClause)
-            else:
-                whereClause.append("flag = '%s'" % flag)
-
-        if start > -1:
-            whereClause.append("start > %d" % start)
-
-        if stop > -1:
-            whereClause.append("stop < %d" % stop)
-
-        if len(readLike) > 0:
-            readIDClause = string.join(["readID LIKE  '", readLike, "%'"], "")
-            whereClause.append(readIDClause)
-
-        if hasMismatch:
-            whereClause.append("mismatch != ''")
-
-        if strand in ["+", "-"]:
-            whereClause.append("sense = '%s'" % strand)
-
-        if len(whereClause) > 0:
-            whereStatement = string.join(whereClause, " and ")
-            whereQuery = "where %s" % whereStatement
-        else:
-            whereQuery = ""
 
 
-        groupBy = []
+        whereQuery = self.getReadWhereQuery(chrom, flag, flagLike, start, stop, hasMismatch, strand, readLike)
         if findallOptimize:
         if findallOptimize:
-            selectClause = ["select start, sense, sum(weight)"]
-            groupBy = ["GROUP BY start, sense"]
+            selectQuery = "select start, sense, sum(weight)"
         else:
         else:
-            selectClause = ["select ID, chrom, start, readID"]
-            if bothEnds:
-                selectClause.append("stop")
-
-            if not noSense:
-                selectClause.append("sense")
-
-            if withWeight:
-                selectClause.append("weight")
-
-            if withFlag:
-                selectClause.append("flag")
+            selectQuery = self.getReadSelectQuery("select ID, chrom, start, readID", noSense, withWeight, withFlag, withMismatch, bothEnds)
 
 
-            if withMismatch:
-                selectClause.append("mismatch")
-
-        if limit > 0 and not combine5p:
-            groupBy.append("LIMIT %d" % limit)
-
-        selectQuery = string.join(selectClause, ",")
-        groupQuery = string.join(groupBy)
+        groupQuery = self.getReadGroupQuery(findallOptimize, limit, combine5p)
         if doUniqs:
             stmt = [selectQuery, "from uniqs", whereQuery, groupQuery]
             if doMulti:
         if doUniqs:
             stmt = [selectQuery, "from uniqs", whereQuery, groupQuery]
             if doMulti:
@@ -498,6 +445,7 @@ class ReadDataset():
         sqlQuery = string.join(stmt)
         sql.execute(sqlQuery)
 
         sqlQuery = string.join(stmt)
         sql.execute(sqlQuery)
 
+        resultsDict = {}
         if findallOptimize:
             resultsDict[chrom] = [{"start": int(row[0]), "sense": row[1], "weight": float(row[2])} for row in sql]
             if self.memBacked:
         if findallOptimize:
             resultsDict[chrom] = [{"start": int(row[0]), "sense": row[1], "weight": float(row[2])} for row in sql]
             if self.memBacked:
@@ -562,20 +510,17 @@ class ReadDataset():
         return resultsDict
 
 
         return resultsDict
 
 
-    def getSplicesDict(self, noSense=False, fullChrom=False, chrom="",
-                       flag="", withWeight=False, withFlag=False, withMismatch=False,
-                       withID=False, withChrom=False, withPairID=False, readIDDict=False,
-                       splitRead=False, hasMismatch=False, flagLike=False, start=-1,
-                       stop=-1, strand=""):
-        """ returns a dictionary of spliced reads in a variety of
-        formats and which can be restricted by chromosome or custom-flag.
-        Returns unique spliced reads for now.
-        """
-        whereClause = []
-        resultsDict = {}
+    def getReadWhereQuery(self, chrom, flag, flagLike, start, stop, hasMismatch, strand, readLike="", splice=False):
+        if splice:
+            startText = "startL"
+            stopText = "stopR"
+        else:
+            startText = "start"
+            stopText = "stop"
 
 
+        whereClause = []
         if chrom != "" and chrom != self.memChrom:
         if chrom != "" and chrom != self.memChrom:
-            whereClause = ["chrom = '%s'" % chrom]
+            whereClause.append("chrom = '%s'" % chrom)
 
         if flag != "":
             if flagLike:
 
         if flag != "":
             if flagLike:
@@ -584,25 +529,37 @@ class ReadDataset():
             else:
                 whereClause.append("flag = '%s'" % flag)
 
             else:
                 whereClause.append("flag = '%s'" % flag)
 
+        if start > -1:
+            whereClause.append("%s > %d" % (startText, start))
+
+        if stop > -1:
+            whereClause.append("%s < %d" % (stopText, stop))
+
+        if len(readLike) > 0:
+            readIDClause = string.join(["readID LIKE  '", readLike, "%'"], "")
+            whereClause.append(readIDClause)
+
         if hasMismatch:
             whereClause.append("mismatch != ''")
 
         if hasMismatch:
             whereClause.append("mismatch != ''")
 
-        if strand != "":
+        if strand in ["+", "-"]:
             whereClause.append("sense = '%s'" % strand)
 
             whereClause.append("sense = '%s'" % strand)
 
-        if start > -1:
-            whereClause.append("startL > %d" % start)
-
-        if stop > -1:
-            whereClause.append("stopR < %d" % stop)
-
         if len(whereClause) > 0:
             whereStatement = string.join(whereClause, " and ")
             whereQuery = "where %s" % whereStatement
         else:
             whereQuery = ""
 
         if len(whereClause) > 0:
             whereStatement = string.join(whereClause, " and ")
             whereQuery = "where %s" % whereStatement
         else:
             whereQuery = ""
 
-        selectClause = ["select ID, chrom, startL, stopL, startR, stopR, readID"]
+        return whereQuery
+
+
+    def getReadSelectQuery(self, baseSelect, noSense, withWeight, withFlag, withMismatch, bothEnds=False):
+
+        selectClause = [baseSelect]
+        if bothEnds:
+            selectClause.append("stop")
+
         if not noSense:
             selectClause.append("sense")
 
         if not noSense:
             selectClause.append("sense")
 
@@ -615,7 +572,36 @@ class ReadDataset():
         if withMismatch:
             selectClause.append("mismatch")
 
         if withMismatch:
             selectClause.append("mismatch")
 
-        selectQuery = string.join(selectClause, " ,")
+        selectQuery = string.join(selectClause, ",")
+
+        return selectQuery
+
+
+    def getReadGroupQuery(self, findallOptimize, limit, combine5p):
+        groupBy = []
+        if findallOptimize:
+            groupBy = ["GROUP BY start, sense"]
+
+        if limit > 0 and not combine5p:
+            groupBy.append("LIMIT %d" % limit)
+
+        groupQuery = string.join(groupBy)
+
+        return groupQuery
+
+
+    def getSplicesDict(self, noSense=False, fullChrom=False, chrom="",
+                       flag="", withWeight=False, withFlag=False, withMismatch=False,
+                       withID=False, withChrom=False, withPairID=False, readIDDict=False,
+                       splitRead=False, hasMismatch=False, flagLike=False, start=-1,
+                       stop=-1, strand=""):
+        """ returns a dictionary of spliced reads in a variety of
+        formats and which can be restricted by chromosome or custom-flag.
+        Returns unique spliced reads for now.
+        """
+        whereQuery = self.getReadWhereQuery(chrom, flag, flagLike, start, stop, hasMismatch, strand, splice=True)
+        selectClause = "select ID, chrom, startL, stopL, startR, stopR, readID"
+        selectQuery = self.getReadSelectQuery(selectClause, noSense, withWeight, withFlag, withMismatch)
         if self.memBacked:
             sql = self.memcon.cursor()
         else:
         if self.memBacked:
             sql = self.memcon.cursor()
         else:
@@ -625,6 +611,7 @@ class ReadDataset():
         sql.execute(stmt)
         currentReadID = ""
         currentChrom = ""
         sql.execute(stmt)
         currentReadID = ""
         currentChrom = ""
+        resultsDict = {}
         for row in sql:
             pairID = 0
             readID = row["readID"]
         for row in sql:
             pairID = 0
             readID = row["readID"]
@@ -796,10 +783,6 @@ class ReadDataset():
         """ get readID's.
         """
         stmt = []
         """ get readID's.
         """
         stmt = []
-        limitPart = ""
-        if limit > 0:
-            limitPart = "LIMIT %d" % limit
-
         if uniqs:
             stmt.append("select readID from uniqs")
 
         if uniqs:
             stmt.append("select readID from uniqs")
 
@@ -814,6 +797,10 @@ class ReadDataset():
         else:
             selectPart = ""
 
         else:
             selectPart = ""
 
+        limitPart = ""
+        if limit > 0:
+            limitPart = "LIMIT %d" % limit
+
         sqlQuery = "%s group by readID %s" % (selectPart, limitPart)
         if self.memBacked:
             sql = self.memcon.cursor()
         sqlQuery = "%s group by readID %s" % (selectPart, limitPart)
         if self.memBacked:
             sql = self.memcon.cursor()
@@ -845,12 +832,11 @@ class ReadDataset():
                 print "getting mismatches from chromosome %s" % (achrom)
 
             snpDict[achrom] = []
                 print "getting mismatches from chromosome %s" % (achrom)
 
             snpDict[achrom] = []
-            hitDict = self.getReadsDict(fullChrom=True, chrom=achrom, withMismatch=True, hasMismatch=True)
             if useSplices and self.dataType == "RNA":
                 spliceDict = self.getSplicesDict(fullChrom=True, chrom=achrom, withMismatch=True, readIDDict=True, hasMismatch=True)
                 spliceIDList = spliceDict.keys()
             if useSplices and self.dataType == "RNA":
                 spliceDict = self.getSplicesDict(fullChrom=True, chrom=achrom, withMismatch=True, readIDDict=True, hasMismatch=True)
                 spliceIDList = spliceDict.keys()
-                for k in spliceIDList:
-                    spliceEntry = spliceDict[k][0]
+                for spliceID in spliceIDList:
+                    spliceEntry = spliceDict[spliceID][0]
                     startpos = spliceEntry["startL"]
                     lefthalf = spliceEntry["stopL"]
                     rightstart = spliceEntry["startR"]
                     startpos = spliceEntry["startL"]
                     lefthalf = spliceEntry["stopL"]
                     rightstart = spliceEntry["startR"]
@@ -881,6 +867,7 @@ class ReadDataset():
 
                         snpDict[achrom].append([startpos, change_at, change_base, change_from])
 
 
                         snpDict[achrom].append([startpos, change_at, change_base, change_from])
 
+            hitDict = self.getReadsDict(fullChrom=True, chrom=achrom, withMismatch=True, hasMismatch=True)
             if achrom not in hitDict.keys():
                 continue
 
             if achrom not in hitDict.keys():
                 continue
 
@@ -910,7 +897,7 @@ class ReadDataset():
 
 
     def getChromProfile(self, chromosome, cstart=-1, cstop=-1, useMulti=True,
 
 
     def getChromProfile(self, chromosome, cstart=-1, cstop=-1, useMulti=True,
-                        useSplices=False, normalizationFactor = 1.0, trackStrand=False,
+                        useSplices=False, normalizationFactor=1.0, trackStrand=False,
                         keepStrand="both", shiftValue=0):
         """return a profile of the chromosome as an array of per-base read coverage....
             keepStrand = 'both', 'plusOnly', or 'minusOnly'.
                         keepStrand="both", shiftValue=0):
         """return a profile of the chromosome as an array of per-base read coverage....
             keepStrand = 'both', 'plusOnly', or 'minusOnly'.
@@ -925,8 +912,8 @@ class ReadDataset():
         dataType = metadata["dataType"]
         scale = 1. / normalizationFactor
         shift = {}
         dataType = metadata["dataType"]
         scale = 1. / normalizationFactor
         shift = {}
-        shift['+'] = int(shiftValue)
-        shift['-'] = -1 * int(shiftValue)
+        shift["+"] = int(shiftValue)
+        shift["-"] = -1 * int(shiftValue)
 
         if cstop > 0:
             lastNT = self.getMaxCoordinate(chromosome, doMulti=useMulti, doSplices=useSplices) + readlen
 
         if cstop > 0:
             lastNT = self.getMaxCoordinate(chromosome, doMulti=useMulti, doSplices=useSplices) + readlen