first pass cleanup of cistematic/genomes; change bamPreprocessing
[erange.git] / geneMrnaCountsWeighted.py
# Optional JIT accelerator (Python 2 / 32-bit only).  Any failure to load it is
# non-fatal, but interpreter-exit signals (KeyboardInterrupt, SystemExit) are
# no longer swallowed the way a bare "except:" would.
try:
    import psyco
    psyco.full()
except Exception:
    print("psyco not running")
6
7 import sys
8 import optparse
9 import pysam
10 from commoncode import getMergedRegions, getFeaturesByChromDict, getConfigParser, getConfigOption, getConfigBoolOption, getHeaderComment, addPairIDtoReadID, isSpliceEntry
11 from cistematic.genomes import Genome
12 from commoncode import getGeneInfoDict
13 from cistematic.core import chooseDB, cacheGeneDB, uncacheGeneDB
14
# Announce the script version on import/startup.
print("geneMrnaCountsWeighted: version 5.0")
16
17
def main(argv=None):
    """Command-line entry point.

    Parses *argv* (defaults to sys.argv), opens the BAM file given as the
    second positional argument and runs geneMrnaCountsWeighted.  Exits with
    status 1 when fewer than four positional arguments are supplied.  The BAM
    handle is closed when counting finishes, whether or not it succeeds.
    """
    if not argv:
        argv = sys.argv

    # the "%s" placeholder was previously never filled in, so users saw a
    # literal "python %s ..." in the usage text
    usage = "usage: python %s genome bamfile uniqcountfile outfile [options]" % argv[0]

    parser = makeParser(usage)
    (options, args) = parser.parse_args(argv[1:])

    if len(args) < 4:
        print(usage)
        sys.exit(1)

    genome = args[0]
    hitfile = args[1]
    countfile = args[2]
    outfilename = args[3]
    bamfile = pysam.Samfile(hitfile, "rb")
    try:
        geneMrnaCountsWeighted(genome, bamfile, countfile, outfilename, options.ignoreSense,
                               options.withUniqs, options.withMulti,
                               options.acceptfile, options.cachePages, options.doVerbose,
                               options.extendGenome, options.replaceModels)
    finally:
        bamfile.close()
41
42
def makeParser(usage=""):
    """Build the optparse parser for this script.

    Option defaults are taken from the "geneMrnaCountsWeighted" section of the
    configuration returned by getConfigParser(), falling back to the literal
    defaults below when a key is absent.
    """
    parser = optparse.OptionParser(usage=usage)

    # (flag, add_option keyword arguments) -- registration order is preserved
    optionSpecs = [
        ("--stranded", {"action": "store_false", "dest": "ignoreSense"}),
        ("--uniq", {"action": "store_true", "dest": "withUniqs"}),
        ("--multi", {"action": "store_true", "dest": "withMulti"}),
        ("--accept", {"dest": "acceptfile"}),
        ("--cache", {"type": "int", "dest": "cachePages"}),
        ("--verbose", {"action": "store_true", "dest": "doVerbose"}),
        ("--models", {"dest": "extendGenome"}),
        ("--replacemodels", {"action": "store_true", "dest": "replaceModels"}),
    ]
    for flag, spec in optionSpecs:
        parser.add_option(flag, **spec)

    config = getConfigParser()
    sectionName = "geneMrnaCountsWeighted"
    # keyword arguments are evaluated left to right, so the config is queried
    # in the same order as before
    parser.set_defaults(
        ignoreSense=getConfigBoolOption(config, sectionName, "ignoreSense", True),
        withUniqs=getConfigBoolOption(config, sectionName, "withUniqs", False),
        withMulti=getConfigBoolOption(config, sectionName, "withMulti", False),
        acceptfile=getConfigOption(config, sectionName, "acceptfile", None),
        cachePages=getConfigOption(config, sectionName, "cachePages", None),
        doVerbose=getConfigBoolOption(config, sectionName, "doVerbose", False),
        extendGenome=getConfigOption(config, sectionName, "extendGenome", ""),
        replaceModels=getConfigBoolOption(config, sectionName, "replaceModels", False),
    )

    return parser
70
71
72 def geneMrnaCountsWeighted(genome, bamfile, countfile, outfilename, ignoreSense=True,
73                            withUniqs=False, withMulti=False, acceptfile=None,
74                            cachePages=None, doVerbose=False, extendGenome="", replaceModels=False):
75
76     if (not withUniqs and not withMulti) or (withUniqs and withMulti):
77         print "must have either one of -uniq or -multi set. Exiting"
78         sys.exit(1)
79
80     if cachePages is not None:
81         cacheGeneDB(genome)
82         hg = Genome(genome, dbFile=chooseDB(genome), inRAM=True)
83         print "%s cached" % genome
84         doCache = True
85     else:
86         doCache = False
87         hg = Genome(genome, inRAM=True)
88
89     if extendGenome:
90         if replaceModels:
91             print "will replace gene models with %s" % extendGenome
92         else:
93             print "will extend gene models with %s" % extendGenome
94
95         hg.extendFeatures(extendGenome, replace=replaceModels)
96
97     allGIDs = set(hg.allGIDs())
98     if acceptfile is not None:
99         regionDict = getMergedRegions(acceptfile, maxDist=0, keepLabel=True, verbose=doVerbose)
100         for chrom in regionDict:
101             for region in regionDict[chrom]:
102                 allGIDs.add(region.label)
103     else:
104         regionDict = {}
105
106     featuresByChromDict = getFeaturesByChromDict(hg, regionDict)
107
108     gidReadDict = {}
109     read2GidDict = {}
110     for gid in allGIDs:
111         gidReadDict[gid] = []
112
113     index = 0
114     chromosomeList = [chrom for chrom in bamfile.references if chrom != "chrM"]
115     for chromosome in chromosomeList:
116         if doNotProcessChromosome(chromosome, featuresByChromDict.keys()):
117             continue
118
119         print "\n%s " % chromosome,
120         featureList = featuresByChromDict[chromosome]
121
122         readGidList, totalProcessedReads = getReadGIDs(bamfile, chromosome, featureList, index, withUniqs, withMulti, ignoreSense=ignoreSense)
123         index = totalProcessedReads
124         for (tagReadID, gid) in readGidList:
125             try:
126                 gidReadDict[gid].append(tagReadID)
127                 if tagReadID in read2GidDict:
128                     read2GidDict[tagReadID].add(gid)
129                 else:
130                     read2GidDict[tagReadID] = set([gid])
131             except KeyError:
132                 print "gid %s not in gidReadDict" % gid
133
134     writeCountsToFile(outfilename, countfile, allGIDs, hg, gidReadDict, read2GidDict, doVerbose, doCache)
135     if doCache:
136         uncacheGeneDB(genome)
137
138
def doNotProcessChromosome(chromosome, chromosomeList):
    """Return True when *chromosome* is not a member of *chromosomeList*."""
    isKnownChromosome = chromosome in chromosomeList
    return not isKnownChromosome
141
142
143 def getReadGIDs(bamfile, chromosome, featList, index, doUniqs, doMulti, ignoreSense=True):
144
145     startFeature = 0
146     readGidList = []
147     readlen = getHeaderComment(bamfile.header, "ReadLength")
148     for alignedread in bamfile.fetch(chromosome):
149         if doNotProcessRead(alignedread, doUniqs, doMulti):
150             continue
151
152         tagStart = alignedread.pos
153         tagReadID = addPairIDtoReadID(alignedread)
154         tagSense = ""
155         if not ignoreSense:
156             if alignedread.is_reverse:
157                 tagSense = "-"
158             else:
159                 tagSense = "+"
160
161         index += 1
162         if index % 100000 == 0:
163             print "read %d" % index,
164
165         stopPoint = tagStart + readlen
166         if startFeature < 0:
167             startFeature = 0
168
169         for (start, stop, gid, sense, ftype) in featList[startFeature:]:
170             if tagStart > stop:
171                 startFeature += 1
172                 continue
173
174             if start > stopPoint:
175                 startFeature -= 100
176                 break
177
178             if not ignoreSense:
179                 if sense == "R":
180                     sense = "-"
181                 else:
182                     sense = "+"
183
184             if start <= tagStart <= stop and (ignoreSense or tagSense == sense):
185                 readGidList.append((tagReadID, gid))
186                 stopPoint = stop
187
188     return readGidList, index
189
190
def doNotProcessRead(read, doUniqs, doMulti):
    """Return True if *read* must be excluded from counting.

    Reads for which isSpliceEntry(read.cigar) is true are always excluded.
    With *doUniqs*, reads whose NH tag is greater than 1 are excluded; with
    *doMulti*, reads whose NH tag equals 1 are excluded.  The NH tag is only
    read when one of those flags is set.
    """
    if isSpliceEntry(read.cigar):
        return True

    if doUniqs and read.opt('NH') > 1:
        return True

    if doMulti and read.opt('NH') == 1:
        return True

    return False
201
202
def writeCountsToFile(outFilename, countFilename, allGIDs, genome, gidReadDict, read2GidDict, doVerbose=False, doCache=False):
    """Write one "gid<TAB>symbol<TAB>count" line per gene in *allGIDs*.

    *countFilename* is read as whitespace-separated lines whose first field is
    a gid and whose last field is that gene's unique-read count; each count is
    stored with a +1 pseudo-count for the weighting done in getTagCount.
    Blank lines in that file are skipped.  Counts are written with "%d", i.e.
    truncated to integers.  Both files are closed even if an error occurs.
    """
    uniqueCountDict = {}
    uniquecounts = open(countFilename)
    try:
        for line in uniquecounts:
            fields = line.split()
            if not fields:
                # a blank line previously raised IndexError on fields[0]
                continue

            # add a pseudo-count here to ease calculations below
            #TODO: figure out why this was done in prior implementation...
            uniqueCountDict[fields[0]] = float(fields[-1]) + 1
    finally:
        uniquecounts.close()

    genomeName = genome.genome
    geneinfoDict = getGeneInfoDict(genomeName, cache=doCache)
    geneannotDict = genome.allAnnotInfo()
    outfile = open(outFilename, "w")
    try:
        for gid in allGIDs:
            symbol = getGeneSymbol(gid, genomeName, geneinfoDict, geneannotDict)
            tagCount = getTagCount(uniqueCountDict, gid, gidReadDict, read2GidDict)
            if doVerbose:
                print("%s %s %f" % (gid, symbol, tagCount))

            outfile.write("%s\t%s\t%d\n" % (gid, symbol, tagCount))
    finally:
        outfile.close()
228
229
def getGeneSymbol(gid, genomeName, geneinfoDict, geneannotDict):
    """Return a display symbol for *gid*.

    Gene IDs containing "FAR" are returned unchanged.  Otherwise the symbol is
    taken from geneinfoDict[gid][0] (column 1 for "celegans", column 0 for
    every other genome), then from geneannotDict[(genomeName, gid)][0], and
    finally falls back to "LOC<gid>" when neither lookup succeeds.
    """
    if "FAR" in gid:
        return gid

    symbolColumn = 1 if genomeName == "celegans" else 0
    try:
        return geneinfoDict[gid][0][symbolColumn]
    except (KeyError, IndexError):
        pass

    try:
        return geneannotDict[(genomeName, gid)][0]
    except (KeyError, IndexError):
        return "LOC%s" % gid
249
250
def getTagCount(uniqueCountDict, gid, gidReadDict, read2GidDict):
    """Return the weighted read count for *gid*.

    Each read assigned to *gid* contributes
        weight(gid) / sum(weight(g) for g in genes that read hits)
    where weight(g) is uniqueCountDict[g] (the pseudo-counted unique count)
    or 1 when g is absent from the count file.  Previously an absent *gid*
    left its numerator unbound and raised UnboundLocalError; it now uses the
    same default of 1 the denominator already applied.
    """
    # the numerator does not depend on the read, so look it up once
    tagValue = uniqueCountDict.get(gid, 1.)
    tagCount = 0.
    for readID in gidReadDict[gid]:
        tagDenom = 0.
        for relatedGID in read2GidDict[readID]:
            tagDenom += uniqueCountDict.get(relatedGID, 1)

        if tagDenom:
            tagCount += tagValue / tagDenom

    return tagCount
272
273
# Run the command-line entry point only when executed as a script.
if __name__ == "__main__":
    main(argv=sys.argv)