first pass cleanup of cistematic/genomes; change bamPreprocessing
[erange.git] / geneMrnaCountsWeighted.py
# psyco is an optional accelerator: the script must keep working without it.
# Catch Exception rather than using a bare except so that KeyboardInterrupt and
# SystemExit are not swallowed while the accelerator is being set up.
try:
    import psyco
    psyco.full()
except Exception:
    print("psyco not running")
6
7 import sys
8 import optparse
9 from commoncode import getMergedRegions, getFeaturesByChromDict, getConfigParser, getConfigOption, getConfigBoolOption
10 import ReadDataset
11 from cistematic.genomes import Genome
12 from commoncode import getGeneInfoDict
13 from cistematic.core import chooseDB, cacheGeneDB, uncacheGeneDB
14
# Announce the script version as soon as the module is loaded.
print("geneMrnaCountsWeighted: version 4.3")
16
17
def main(argv=None):
    """Command-line entry point.

    argv is a full argument vector (script name first); sys.argv is used when
    it is missing or empty. The four required positional arguments are
    genome, rdsfile, uniqcountfile and outfile. With fewer than four the
    usage line is printed and the process exits with status 1.
    """
    if not argv:
        argv = sys.argv

    # Fill in the script name so neither optparse's help nor the message
    # printed below shows a literal "%s".
    usage = "usage: python %s genome rdsfile uniqcountfile outfile [options]" % argv[0]

    parser = makeParser(usage)
    (options, args) = parser.parse_args(argv[1:])

    if len(args) < 4:
        print(usage)
        sys.exit(1)

    genome, hitfile, countfile, outfilename = args[:4]

    geneMrnaCountsWeighted(genome, hitfile, countfile, outfilename, options.ignoreSense,
                           options.withUniqs, options.withMulti,
                           options.acceptfile, options.cachePages, options.doVerbose,
                           options.extendGenome, options.replaceModels)
40
41
def makeParser(usage=""):
    """Build the optparse parser for this script.

    Command-line switches are registered first; their defaults are then read
    from the "geneMrnaCountsWeighted" section of the configuration returned
    by getConfigParser(), falling back to the hard-coded values given here.
    """
    optionParser = optparse.OptionParser(usage=usage)

    # (switch, add_option keyword arguments), registered in this order.
    optionSpecs = (
        ("--stranded", {"action": "store_false", "dest": "ignoreSense"}),
        ("--uniq", {"action": "store_true", "dest": "withUniqs"}),
        ("--multi", {"action": "store_true", "dest": "withMulti"}),
        ("--accept", {"dest": "acceptfile"}),
        ("--cache", {"type": "int", "dest": "cachePages"}),
        ("--verbose", {"action": "store_true", "dest": "doVerbose"}),
        ("--models", {"dest": "extendGenome"}),
        ("--replacemodels", {"action": "store_true", "dest": "replaceModels"}),
    )
    for switch, settings in optionSpecs:
        optionParser.add_option(switch, **settings)

    config = getConfigParser()
    sectionName = "geneMrnaCountsWeighted"
    defaults = {
        "ignoreSense": getConfigBoolOption(config, sectionName, "ignoreSense", True),
        "withUniqs": getConfigBoolOption(config, sectionName, "withUniqs", False),
        "withMulti": getConfigBoolOption(config, sectionName, "withMulti", False),
        "acceptfile": getConfigOption(config, sectionName, "acceptfile", None),
        "cachePages": getConfigOption(config, sectionName, "cachePages", None),
        "doVerbose": getConfigBoolOption(config, sectionName, "doVerbose", False),
        "extendGenome": getConfigOption(config, sectionName, "extendGenome", ""),
        "replaceModels": getConfigBoolOption(config, sectionName, "replaceModels", False),
    }
    optionParser.set_defaults(**defaults)

    return optionParser
69
70
#TODO: Reported user performance issue. Long run times occur under these conditions:
#    small number of reads (~40-50M)
#    all features on a single chromosome
#
#    The user states that this has been a long-standing problem.
76
def geneMrnaCountsWeighted(genome, hitfile, countfile, outfilename, ignoreSense=True,
                           withUniqs=False, withMulti=False, acceptfile=None,
                           cachePages=None, doVerbose=False, extendGenome="", replaceModels=False):
    """Assign reads from an RDS file to gene ids and write per-gene counts.

    Exactly one of withUniqs / withMulti must be true; otherwise a message is
    printed and the process exits with status 1.

    genome        genome name handed to cistematic's Genome
    hitfile       RDS file opened through ReadDataset
    countfile     per-gene count file, read by writeCountsToFile
    outfilename   output path, written by writeCountsToFile
    ignoreSense   passed to getReadsDict as noSense
    acceptfile    optional region file; its region labels are added as gene ids
    cachePages    when not None the gene database is cached for the run and
                  the RDS cache is enlarged if this exceeds its default size
    extendGenome  optional extra gene models, extending or (with replaceModels)
                  replacing the built-in ones
    """
    if (not withUniqs and not withMulti) or (withUniqs and withMulti):
        print("must have either one of -uniq or -multi set. Exiting")
        sys.exit(1)

    doCache = cachePages is not None
    if doCache:
        cacheGeneDB(genome)
    else:
        cachePages = 0

    # Everything after cacheGeneDB runs under try/finally so that the cached
    # gene database is released even when processing fails part-way through.
    try:
        if doCache:
            hg = Genome(genome, dbFile=chooseDB(genome), inRAM=True)
            print("%s cached" % genome)
        else:
            hg = Genome(genome, inRAM=True)

        if extendGenome:
            if replaceModels:
                print("will replace gene models with %s" % extendGenome)
            else:
                print("will extend gene models with %s" % extendGenome)

            hg.extendFeatures(extendGenome, replace=replaceModels)

        hitRDS = ReadDataset.ReadDataset(hitfile, verbose=doVerbose, cache=doCache)
        if cachePages > hitRDS.getDefaultCacheSize():
            hitRDS.setDBcache(cachePages)

        allGIDs = set(hg.allGIDs())
        if acceptfile is not None:
            regionDict = getMergedRegions(acceptfile, maxDist=0, keepLabel=True, verbose=doVerbose)
            for chrom in regionDict:
                for region in regionDict[chrom]:
                    allGIDs.add(region.label)
        else:
            regionDict = {}

        featuresByChromDict = getFeaturesByChromDict(hg, regionDict)

        # gid -> read ids that hit it, and read id -> set of gids it hit.
        gidReadDict = dict((gid, []) for gid in allGIDs)
        read2GidDict = {}

        if withMulti and not withUniqs:
            chromList = hitRDS.getChromosomes(table="multi", fullChrom=False)
        else:
            chromList = hitRDS.getChromosomes(fullChrom=False)

        readlen = hitRDS.getReadSize()
        # Hoisted: the set of chromosomes that have features does not change
        # while the reads are being scanned.
        featureChromosomes = featuresByChromDict.keys()
        index = 0
        for chromosome in chromList:
            if doNotProcessChromosome(chromosome, featureChromosomes):
                continue

            # Progress marker; the read counters printed by getReadGIDs follow
            # on the same line.
            sys.stdout.write("\n%s " % chromosome)
            fullchrom = "chr%s" % chromosome
            hitDict = hitRDS.getReadsDict(noSense=ignoreSense, fullChrom=True, chrom=fullchrom,
                                          withID=True, doUniqs=withUniqs, doMulti=withMulti)
            featureList = featuresByChromDict[chromosome]

            # index carries the running read total across chromosomes.
            readGidList, index = getReadGIDs(hitDict, fullchrom, featureList, readlen, index)
            for (tagReadID, gid) in readGidList:
                try:
                    gidReadDict[gid].append(tagReadID)
                except KeyError:
                    print("gid %s not in gidReadDict" % gid)
                    continue

                read2GidDict.setdefault(tagReadID, set()).add(gid)

        writeCountsToFile(outfilename, countfile, allGIDs, hg, gidReadDict, read2GidDict, doVerbose, doCache)
    finally:
        if doCache:
            uncacheGeneDB(genome)
154
155
def doNotProcessChromosome(chromosome, chromosomeList):
    """Return True when chromosome is absent from chromosomeList."""
    isListed = chromosome in chromosomeList
    return not isListed
158
159
def getReadGIDs(hitDict, fullchrom, featList, readlen, index):
    """Pair every read on one chromosome with the gene ids of the features it starts in.

    hitDict[fullchrom] is iterated as read dicts holding "start" and "readID"
    and, optionally, "sense". featList entries are unpacked as
    (start, stop, gid, sense, ftype); a feature sense of "R" is compared as
    "-" and anything else as "+".
    NOTE(review): the early break and the moving start offset only make sense
    if featList is ordered by position - confirm against getFeaturesByChromDict.

    index is the number of reads processed before this call. Returns
    (readGidList, index) where readGidList holds (readID, gid) tuples and
    index is the updated running total.
    """
    startFeature = 0
    readGidList = []
    ignoreSense = True
    featureCount = len(featList)
    for read in hitDict[fullchrom]:
        tagStart = read["start"]
        tagReadID = read["readID"]
        if "sense" in read:
            tagSense = read["sense"]
            # Sticky: once a read carries a strand, every later match is strand-checked.
            ignoreSense = False

        index += 1
        if index % 100000 == 0:
            sys.stdout.write("read %d " % index)

        stopPoint = tagStart + readlen
        if startFeature < 0:
            startFeature = 0

        # Walk the features by position instead of slicing featList[startFeature:],
        # which copied the tail of the list once per read.
        position = startFeature
        while position < featureCount:
            (start, stop, gid, sense, ftype) = featList[position]
            position += 1

            if tagStart > stop:
                # Feature ends before this read: later reads can start past it.
                startFeature += 1
                continue

            if start > stopPoint:
                # Back up 100 features before the next read; clamped to 0 above.
                startFeature -= 100
                break

            if not ignoreSense:
                if sense == "R":
                    sense = "-"
                else:
                    sense = "+"

            if start <= tagStart <= stop and (ignoreSense or tagSense == sense):
                readGidList.append((tagReadID, gid))
                # Only the end of the matched feature now bounds the scan.
                stopPoint = stop

    return readGidList, index
200
201
def writeCountsToFile(outFilename, countFilename, allGIDs, genome, gidReadDict, read2GidDict, doVerbose=False, doCache=False):
    """Write one "gid<TAB>symbol<TAB>count" line per gene id to outFilename.

    countFilename is read as whitespace-separated lines whose first field is
    a gene id and whose last field is a number; blank lines are skipped.
    The count written for each gid comes from getTagCount and is formatted
    with %d, i.e. as an integer. With doVerbose each gid is also printed
    together with its unrounded count.
    """
    uniqueCountDict = {}
    uniquecounts = open(countFilename)
    try:
        for line in uniquecounts:
            fields = line.split()
            if not fields:
                # A blank line carries no gene id; skip it instead of failing.
                continue

            # add a pseudo-count here to ease calculations below
            #TODO: figure out why this was done in prior implementation...
            uniqueCountDict[fields[0]] = float(fields[-1]) + 1
    finally:
        uniquecounts.close()

    genomeName = genome.genome
    geneinfoDict = getGeneInfoDict(genomeName, cache=doCache)
    geneannotDict = genome.allAnnotInfo()
    outfile = open(outFilename, "w")
    try:
        for gid in allGIDs:
            symbol = getGeneSymbol(gid, genomeName, geneinfoDict, geneannotDict)
            tagCount = getTagCount(uniqueCountDict, gid, gidReadDict, read2GidDict)
            if doVerbose:
                print("%s %s %f" % (gid, symbol, tagCount))

            outfile.write("%s\t%s\t%d\n" % (gid, symbol, tagCount))
    finally:
        # Close the handle even if a gene fails to resolve or format.
        outfile.close()
227
228
def getGeneSymbol(gid, genomeName, geneinfoDict, geneannotDict):
    """Return the display symbol for a gene id.

    Ids containing "FAR" are returned unchanged. Otherwise the symbol is
    taken from the first geneinfoDict record for gid (second field for
    "celegans", first field for every other genome), then from the first
    element of geneannotDict[(genomeName, gid)], and finally falls back to
    "LOC<gid>".
    """
    if "FAR" in gid:
        return gid

    symbolField = 0
    if genomeName == "celegans":
        symbolField = 1

    try:
        return geneinfoDict[gid][0][symbolField]
    except (KeyError, IndexError):
        pass

    try:
        return geneannotDict[(genomeName, gid)][0]
    except (KeyError, IndexError):
        return "LOC%s" % gid
248
249
def getTagCount(uniqueCountDict, gid, gidReadDict, read2GidDict):
    """Return the weighted read count for gid.

    Every read listed for gid contributes weight(gid) divided by the summed
    weights of all gene ids that read was assigned to, where a weight is the
    uniqueCountDict entry or 1 when the gene id has none. Reads whose
    denominator is zero contribute nothing.
    """
    ownWeight = uniqueCountDict.get(gid, 1)
    total = 0.
    for readID in gidReadDict[gid]:
        sharedWeight = 0.
        for otherGID in read2GidDict[readID]:
            sharedWeight += uniqueCountDict.get(otherGID, 1)

        if sharedWeight == 0:
            continue

        total += ownWeight / sharedWeight

    return total
271
272
# Run the command-line entry point only when executed as a script,
# not when the module is imported.
if __name__ == "__main__":
    main(sys.argv)