development checkpoint
[erange.git] / ReadDataset.py
index ef80d657c721c837cbbc6c64d2a007d0fc6aa6e5..850a5ec602a7b6d7a9a22ac98c4bf3eef4b873e5 100644 (file)
@@ -1,25 +1,12 @@
-"""
-Created on Jul 1, 2010
-
-@author: sau
-"""
-
 import sqlite3 as sqlite
 import string
 import tempfile
 import shutil
 import os
 import sqlite3 as sqlite
 import string
 import tempfile
 import shutil
 import os
-from os import environ
 from array import array
 from array import array
-from commoncode import getReverseComplement
-
-if environ.get("CISTEMATIC_TEMP"):
-    cisTemp = environ.get("CISTEMATIC_TEMP")
-else:
-    cisTemp = "/tmp"
+from commoncode import getReverseComplement, getConfigParser, getConfigOption
 
 
-tempfile.tempdir = cisTemp
-currentRDSVersion = "1.1"
+currentRDSVersion = "2.0"
 
 
 class ReadDatasetError(Exception):
 
 
 class ReadDatasetError(Exception):
@@ -140,6 +127,9 @@ class ReadDataset():
     def cacheDB(self, filename):
         """ copy geneinfoDB to a local cache.
         """
     def cacheDB(self, filename):
         """ copy geneinfoDB to a local cache.
         """
+        configParser = getConfigParser()
+        cisTemp = getConfigOption(configParser, "general", "cistematic_temp", default="/tmp")
+        tempfile.tempdir = cisTemp
         self.cachedDBFile =  "%s.db" % tempfile.mktemp()
         shutil.copyfile(filename, self.cachedDBFile)
 
         self.cachedDBFile =  "%s.db" % tempfile.mktemp()
         shutil.copyfile(filename, self.cachedDBFile)
 
@@ -163,42 +153,42 @@ class ReadDataset():
             self.cachedDB = ""
 
 
             self.cachedDB = ""
 
 
-    def attachDB(self, filename, asname):
+    def attachDB(self, filename, dbName):
         """ attach another database file to the readDataset.
         """
         """ attach another database file to the readDataset.
         """
-        stmt = "attach '%s' as %s" % (filename, asname)
+        stmt = "attach '%s' as %s" % (filename, dbName)
         self.execute(stmt)
 
 
         self.execute(stmt)
 
 
-    def detachDB(self, asname):
+    def detachDB(self, dbName):
         """ detach a database file to the readDataset.
         """
         """ detach a database file to the readDataset.
         """
-        stmt = "detach %s" % (asname)
+        stmt = "detach %s" % (dbName)
         self.execute(stmt)
 
 
         self.execute(stmt)
 
 
-    def importFromDB(self, asname, table, ascolumns="*", destcolumns="", flagged=""):
+    def importFromDB(self, dbName, table, ascolumns="*", destcolumns="", flagged=""):
         """ import into current RDS the table (with columns destcolumns,
             with default all columns) from the database file asname,
             using the column specification of ascolumns (default all).
         """
         """ import into current RDS the table (with columns destcolumns,
             with default all columns) from the database file asname,
             using the column specification of ascolumns (default all).
         """
-        stmt = "insert into %s %s select %s from %s.%s" % (table, destcolumns, ascolumns, asname, table)
+        stmt = "insert into %s %s select %s from %s.%s" % (table, destcolumns, ascolumns, dbName, table)
         if flagged != "":
             stmt += " where flag = '%s' " % flagged
 
         self.executeCommit(stmt)
 
 
         if flagged != "":
             stmt += " where flag = '%s' " % flagged
 
         self.executeCommit(stmt)
 
 
-    def getTables(self, asname=""):
+    def getTables(self, dbName=""):
         """ get a list of table names in a particular database file.
         """
         resultList = []
         sql = self.getSqlCursor()
 
         """ get a list of table names in a particular database file.
         """
         resultList = []
         sql = self.getSqlCursor()
 
-        if asname != "":
-            asname += "."
+        if dbName != "":
+            dbName = "%s." % dbName
 
 
-        stmt = "select name from %ssqlite_master where type='table'" % asname
+        stmt = "select name from %ssqlite_master where type='table'" % dbName
         sql.execute(stmt)
         results = sql.fetchall()
 
         sql.execute(stmt)
         results = sql.fetchall()
 
@@ -301,11 +291,15 @@ class ReadDataset():
         if "readsize" not in metadata:
             raise ReadDatasetError("no readsize parameter defined")
         else:
         if "readsize" not in metadata:
             raise ReadDatasetError("no readsize parameter defined")
         else:
-            mysize = metadata["readsize"]
-            if "import" in mysize:
-                mysize = mysize.split()[0]
+            readSize = metadata["readsize"]
+            if "import" in readSize:
+                readSize = readSize.split()[0]
 
 
-            return int(mysize)
+            readSize = int(readSize)
+            if readSize < 0:
+                raise ReadDatasetError("readsize is negative")
+
+            return readSize
 
 
     def getDefaultCacheSize(self):
 
 
     def getDefaultCacheSize(self):
@@ -327,11 +321,9 @@ class ReadDataset():
                 if row["chrom"] not in results:
                     results.append(row["chrom"])
             else:
                 if row["chrom"] not in results:
                     results.append(row["chrom"])
             else:
-                if  len(row["chrom"][3:].strip()) < 1:
-                    continue
-
-                if row["chrom"][3:] not in results:
-                    results.append(row["chrom"][3:])
+                shortName = row["chrom"][3:]
+                if  len(shortName.strip()) > 0 and shortName not in results:
+                    results.append(shortName)
 
         results.sort()
 
 
         results.sort()
 
@@ -343,32 +335,17 @@ class ReadDataset():
         """ returns the maximum coordinate for reads on a given chromosome.
         """
         maxCoord = 0
         """ returns the maximum coordinate for reads on a given chromosome.
         """
         maxCoord = 0
-        sql = self.getSqlCursor()
 
         if doUniqs:
 
         if doUniqs:
-            try:
-                sql.execute("select max(start) from uniqs where chrom = '%s'" % chrom)
-                maxCoord = int(sql.fetchall()[0][0])
-            except:
-                print "couldn't retrieve coordMax for chromosome %s" % chrom
+            maxCoord = self.getMaxStartCoordinateInTable(chrom, "uniqs")
 
         if doSplices:
 
         if doSplices:
-            sql.execute("select max(startR) from splices where chrom = '%s'" % chrom)
-            try:
-                spliceMax = int(sql.fetchall()[0][0])
-                if spliceMax > maxCoord:
-                    maxCoord = spliceMax
-            except:
-                pass
+            spliceMax = self.getMaxStartCoordinateInTable(chrom, "splices", startField="startR")
+            maxCoord = max(spliceMax, maxCoord)
 
         if doMulti:
 
         if doMulti:
-            sql.execute("select max(start) from multi where chrom = '%s'" % chrom)
-            try:
-                multiMax = int(sql.fetchall()[0][0])
-                if multiMax > maxCoord:
-                    maxCoord = multiMax
-            except:
-                pass
+            multiMax = self.getMaxStartCoordinateInTable(chrom, "multi")
+            maxCoord = max(multiMax, maxCoord)
 
         if verbose:
             print "%s maxCoord: %d" % (chrom, maxCoord)
 
         if verbose:
             print "%s maxCoord: %d" % (chrom, maxCoord)
@@ -376,6 +353,19 @@ class ReadDataset():
         return maxCoord
 
 
         return maxCoord
 
 
+    def getMaxStartCoordinateInTable(self, chrom, table, startField="start"):
+        maxCoord = 0
+        sqlStatement = "select max(%s) from %s where chrom = '%s'" % (startField, table, chrom)
+        sql = self.getSqlCursor()
+        try:
+            sql.execute(sqlStatement)
+            maxCoord = int(sql.fetchall()[0][0])
+        except:
+            print "couldn't retrieve coordMax for chromosome %s" % chrom
+
+        return maxCoord
+
+
     def getReadsDict(self, bothEnds=False, noSense=False, fullChrom=False, chrom="",
                      flag="", withWeight=False, withFlag=False, withMismatch=False, withID=False,
                      withChrom=False, withPairID=False, doUniqs=True, doMulti=False, findallOptimize=False,
     def getReadsDict(self, bothEnds=False, noSense=False, fullChrom=False, chrom="",
                      flag="", withWeight=False, withFlag=False, withMismatch=False, withID=False,
                      withChrom=False, withPairID=False, doUniqs=True, doMulti=False, findallOptimize=False,
@@ -385,68 +375,17 @@ class ReadDataset():
         and which can be restricted by chromosome or custom-flag.
         Returns unique reads by default, but can return multireads
         with doMulti set to True.
         and which can be restricted by chromosome or custom-flag.
         Returns unique reads by default, but can return multireads
         with doMulti set to True.
+        
+        Need to rethink original design 1: Cannot have pairID without exporting as a readIDDict
         """
         """
-        whereClause = []
-        resultsDict = {}
-
-        if chrom != "" and chrom != self.memChrom:
-            whereClause.append("chrom = '%s'" % chrom)
-
-        if flag != "":
-            if flagLike:
-                flagLikeClause = string.join(['flag LIKE "%', flag, '%"'], "")
-                whereClause.append(flagLikeClause)
-            else:
-                whereClause.append("flag = '%s'" % flag)
-
-        if start > -1:
-            whereClause.append("start > %d" % start)
-
-        if stop > -1:
-            whereClause.append("stop < %d" % stop)
-
-        if len(readLike) > 0:
-            readIDClause = string.join(["readID LIKE  '", readLike, "%'"], "")
-            whereClause.append(readIDClause)
-
-        if hasMismatch:
-            whereClause.append("mismatch != ''")
-
-        if strand in ["+", "-"]:
-            whereClause.append("sense = '%s'" % strand)
-
-        if len(whereClause) > 0:
-            whereStatement = string.join(whereClause, " and ")
-            whereQuery = "where %s" % whereStatement
-        else:
-            whereQuery = ""
 
 
-        groupBy = []
+        whereQuery = self.getReadWhereQuery(chrom, flag, flagLike, start, stop, hasMismatch, strand, readLike)
         if findallOptimize:
         if findallOptimize:
-            selectClause = ["select start, sense, sum(weight)"]
-            groupBy = ["GROUP BY start, sense"]
+            selectQuery = "select start, sense, sum(weight)"
         else:
         else:
-            selectClause = ["select ID, chrom, start, readID"]
-            if bothEnds:
-                selectClause.append("stop")
+            selectQuery = self.getReadSelectQuery("select ID, chrom, start, readID", noSense, withWeight, withFlag, withMismatch, bothEnds)
 
 
-            if not noSense:
-                selectClause.append("sense")
-
-            if withWeight:
-                selectClause.append("weight")
-
-            if withFlag:
-                selectClause.append("flag")
-
-            if withMismatch:
-                selectClause.append("mismatch")
-
-        if limit > 0 and not combine5p:
-            groupBy.append("LIMIT %d" % limit)
-
-        selectQuery = string.join(selectClause, ",")
-        groupQuery = string.join(groupBy)
+        groupQuery = self.getReadGroupQuery(findallOptimize, limit, combine5p)
         if doUniqs:
             stmt = [selectQuery, "from uniqs", whereQuery, groupQuery]
             if doMulti:
         if doUniqs:
             stmt = [selectQuery, "from uniqs", whereQuery, groupQuery]
             if doMulti:
@@ -506,6 +445,7 @@ class ReadDataset():
         sqlQuery = string.join(stmt)
         sql.execute(sqlQuery)
 
         sqlQuery = string.join(stmt)
         sql.execute(sqlQuery)
 
+        resultsDict = {}
         if findallOptimize:
             resultsDict[chrom] = [{"start": int(row[0]), "sense": row[1], "weight": float(row[2])} for row in sql]
             if self.memBacked:
         if findallOptimize:
             resultsDict[chrom] = [{"start": int(row[0]), "sense": row[1], "weight": float(row[2])} for row in sql]
             if self.memBacked:
@@ -570,20 +510,17 @@ class ReadDataset():
         return resultsDict
 
 
         return resultsDict
 
 
-    def getSplicesDict(self, noSense=False, fullChrom=False, chrom="",
-                       flag="", withWeight=False, withFlag=False, withMismatch=False,
-                       withID=False, withChrom=False, withPairID=False, readIDDict=False,
-                       splitRead=False, hasMismatch=False, flagLike=False, start=-1,
-                       stop=-1, strand=""):
-        """ returns a dictionary of spliced reads in a variety of
-        formats and which can be restricted by chromosome or custom-flag.
-        Returns unique spliced reads for now.
-        """
-        whereClause = []
-        resultsDict = {}
+    def getReadWhereQuery(self, chrom, flag, flagLike, start, stop, hasMismatch, strand, readLike="", splice=False):
+        if splice:
+            startText = "startL"
+            stopText = "stopR"
+        else:
+            startText = "start"
+            stopText = "stop"
 
 
+        whereClause = []
         if chrom != "" and chrom != self.memChrom:
         if chrom != "" and chrom != self.memChrom:
-            whereClause = ["chrom = '%s'" % chrom]
+            whereClause.append("chrom = '%s'" % chrom)
 
         if flag != "":
             if flagLike:
 
         if flag != "":
             if flagLike:
@@ -592,25 +529,37 @@ class ReadDataset():
             else:
                 whereClause.append("flag = '%s'" % flag)
 
             else:
                 whereClause.append("flag = '%s'" % flag)
 
+        if start > -1:
+            whereClause.append("%s > %d" % (startText, start))
+
+        if stop > -1:
+            whereClause.append("%s < %d" % (stopText, stop))
+
+        if len(readLike) > 0:
+            readIDClause = string.join(["readID LIKE  '", readLike, "%'"], "")
+            whereClause.append(readIDClause)
+
         if hasMismatch:
             whereClause.append("mismatch != ''")
 
         if hasMismatch:
             whereClause.append("mismatch != ''")
 
-        if strand != "":
+        if strand in ["+", "-"]:
             whereClause.append("sense = '%s'" % strand)
 
             whereClause.append("sense = '%s'" % strand)
 
-        if start > -1:
-            whereClause.append("startL > %d" % start)
-
-        if stop > -1:
-            whereClause.append("stopR < %d" % stop)
-
         if len(whereClause) > 0:
             whereStatement = string.join(whereClause, " and ")
             whereQuery = "where %s" % whereStatement
         else:
             whereQuery = ""
 
         if len(whereClause) > 0:
             whereStatement = string.join(whereClause, " and ")
             whereQuery = "where %s" % whereStatement
         else:
             whereQuery = ""
 
-        selectClause = ["select ID, chrom, startL, stopL, startR, stopR, readID"]
+        return whereQuery
+
+
+    def getReadSelectQuery(self, baseSelect, noSense, withWeight, withFlag, withMismatch, bothEnds=False):
+
+        selectClause = [baseSelect]
+        if bothEnds:
+            selectClause.append("stop")
+
         if not noSense:
             selectClause.append("sense")
 
         if not noSense:
             selectClause.append("sense")
 
@@ -623,7 +572,36 @@ class ReadDataset():
         if withMismatch:
             selectClause.append("mismatch")
 
         if withMismatch:
             selectClause.append("mismatch")
 
-        selectQuery = string.join(selectClause, " ,")
+        selectQuery = string.join(selectClause, ",")
+
+        return selectQuery
+
+
+    def getReadGroupQuery(self, findallOptimize, limit, combine5p):
+        groupBy = []
+        if findallOptimize:
+            groupBy = ["GROUP BY start, sense"]
+
+        if limit > 0 and not combine5p:
+            groupBy.append("LIMIT %d" % limit)
+
+        groupQuery = string.join(groupBy)
+
+        return groupQuery
+
+
+    def getSplicesDict(self, noSense=False, fullChrom=False, chrom="",
+                       flag="", withWeight=False, withFlag=False, withMismatch=False,
+                       withID=False, withChrom=False, withPairID=False, readIDDict=False,
+                       splitRead=False, hasMismatch=False, flagLike=False, start=-1,
+                       stop=-1, strand=""):
+        """ returns a dictionary of spliced reads in a variety of
+        formats and which can be restricted by chromosome or custom-flag.
+        Returns unique spliced reads for now.
+        """
+        whereQuery = self.getReadWhereQuery(chrom, flag, flagLike, start, stop, hasMismatch, strand, splice=True)
+        selectClause = "select ID, chrom, startL, stopL, startR, stopR, readID"
+        selectQuery = self.getReadSelectQuery(selectClause, noSense, withWeight, withFlag, withMismatch)
         if self.memBacked:
             sql = self.memcon.cursor()
         else:
         if self.memBacked:
             sql = self.memcon.cursor()
         else:
@@ -633,6 +611,7 @@ class ReadDataset():
         sql.execute(stmt)
         currentReadID = ""
         currentChrom = ""
         sql.execute(stmt)
         currentReadID = ""
         currentChrom = ""
+        resultsDict = {}
         for row in sql:
             pairID = 0
             readID = row["readID"]
         for row in sql:
             pairID = 0
             readID = row["readID"]
@@ -785,6 +764,7 @@ class ReadDataset():
     def getSplicesCount(self, chrom="", rmin="", rmax="", restrict="", distinct=False):
         """ returns the number of row in the splices table.
         """
     def getSplicesCount(self, chrom="", rmin="", rmax="", restrict="", distinct=False):
         """ returns the number of row in the splices table.
         """
+        # TODO: if the rds type is DNA should this just return zero?
         return self.getTableEntryCount("splices", chrom, rmin, rmax, restrict, distinct, startField="startL")
 
 
         return self.getTableEntryCount("splices", chrom, rmin, rmax, restrict, distinct, startField="startL")
 
 
@@ -804,10 +784,6 @@ class ReadDataset():
         """ get readID's.
         """
         stmt = []
         """ get readID's.
         """
         stmt = []
-        limitPart = ""
-        if limit > 0:
-            limitPart = "LIMIT %d" % limit
-
         if uniqs:
             stmt.append("select readID from uniqs")
 
         if uniqs:
             stmt.append("select readID from uniqs")
 
@@ -822,6 +798,10 @@ class ReadDataset():
         else:
             selectPart = ""
 
         else:
             selectPart = ""
 
+        limitPart = ""
+        if limit > 0:
+            limitPart = "LIMIT %d" % limit
+
         sqlQuery = "%s group by readID %s" % (selectPart, limitPart)
         if self.memBacked:
             sql = self.memcon.cursor()
         sqlQuery = "%s group by readID %s" % (selectPart, limitPart)
         if self.memBacked:
             sql = self.memcon.cursor()
@@ -853,12 +833,11 @@ class ReadDataset():
                 print "getting mismatches from chromosome %s" % (achrom)
 
             snpDict[achrom] = []
                 print "getting mismatches from chromosome %s" % (achrom)
 
             snpDict[achrom] = []
-            hitDict = self.getReadsDict(fullChrom=True, chrom=achrom, withMismatch=True, hasMismatch=True)
             if useSplices and self.dataType == "RNA":
                 spliceDict = self.getSplicesDict(fullChrom=True, chrom=achrom, withMismatch=True, readIDDict=True, hasMismatch=True)
                 spliceIDList = spliceDict.keys()
             if useSplices and self.dataType == "RNA":
                 spliceDict = self.getSplicesDict(fullChrom=True, chrom=achrom, withMismatch=True, readIDDict=True, hasMismatch=True)
                 spliceIDList = spliceDict.keys()
-                for k in spliceIDList:
-                    spliceEntry = spliceDict[k][0]
+                for spliceID in spliceIDList:
+                    spliceEntry = spliceDict[spliceID][0]
                     startpos = spliceEntry["startL"]
                     lefthalf = spliceEntry["stopL"]
                     rightstart = spliceEntry["startR"]
                     startpos = spliceEntry["startL"]
                     lefthalf = spliceEntry["stopL"]
                     rightstart = spliceEntry["startR"]
@@ -889,6 +868,7 @@ class ReadDataset():
 
                         snpDict[achrom].append([startpos, change_at, change_base, change_from])
 
 
                         snpDict[achrom].append([startpos, change_at, change_base, change_from])
 
+            hitDict = self.getReadsDict(fullChrom=True, chrom=achrom, withMismatch=True, hasMismatch=True)
             if achrom not in hitDict.keys():
                 continue
 
             if achrom not in hitDict.keys():
                 continue
 
@@ -918,7 +898,7 @@ class ReadDataset():
 
 
     def getChromProfile(self, chromosome, cstart=-1, cstop=-1, useMulti=True,
 
 
     def getChromProfile(self, chromosome, cstart=-1, cstop=-1, useMulti=True,
-                        useSplices=False, normalizationFactor = 1.0, trackStrand=False,
+                        useSplices=False, normalizationFactor=1.0, trackStrand=False,
                         keepStrand="both", shiftValue=0):
         """return a profile of the chromosome as an array of per-base read coverage....
             keepStrand = 'both', 'plusOnly', or 'minusOnly'.
                         keepStrand="both", shiftValue=0):
         """return a profile of the chromosome as an array of per-base read coverage....
             keepStrand = 'both', 'plusOnly', or 'minusOnly'.
@@ -933,8 +913,8 @@ class ReadDataset():
         dataType = metadata["dataType"]
         scale = 1. / normalizationFactor
         shift = {}
         dataType = metadata["dataType"]
         scale = 1. / normalizationFactor
         shift = {}
-        shift['+'] = int(shiftValue)
-        shift['-'] = -1 * int(shiftValue)
+        shift["+"] = int(shiftValue)
+        shift["-"] = -1 * int(shiftValue)
 
         if cstop > 0:
             lastNT = self.getMaxCoordinate(chromosome, doMulti=useMulti, doSplices=useSplices) + readlen
 
         if cstop > 0:
             lastNT = self.getMaxCoordinate(chromosome, doMulti=useMulti, doSplices=useSplices) + readlen
@@ -1242,5 +1222,5 @@ class ReadDataset():
             destinationEntries.append((row["readID"], row["chrom"], int(row["startL"]), int(row["stopL"]), int(row["startR"]), int(row["stopR"]), row["sense"],
                                        row["weight"], row["flag"], row["mismatch"]))
 
             destinationEntries.append((row["readID"], row["chrom"], int(row["startL"]), int(row["stopL"]), int(row["startR"]), int(row["stopR"]), row["sense"],
                                        row["weight"], row["flag"], row["mismatch"]))
 
-        self.memcon.executemany("insert into splices(ID, readID, chrom, startL, stopL, startR, stopR, sense, weight, flag, mismatch) values (NULL,?,?,?,?,?,?,?,?)", destinationEntries)
+        self.memcon.executemany("insert into splices(ID, readID, chrom, startL, stopL, startR, stopR, sense, weight, flag, mismatch) values (NULL,?,?,?,?,?,?,?,?,?,?)", destinationEntries)