a302d654c6450e9ca4056d54e8b0a31670fc06b0
[helpful_scripts.git] / translate_tsv_genes.py
1 #!/usr/bin/env python3
2 """Generate a quantification matrix 
3
4 This is intended to extract one quantification column
5 from each of a set of gene quantification files.
6 """
7 # Copyright (2015) Diane Trout & California Institute of Technology
8
9 # This program is free software; you can redistribute it and/or modify
10 # it under the terms of the GNU General Public License as published by
11 # the Free Software Foundation; either version 2 of the License, or
12 # (at your option) any later version.
13 #
14 # This program is distributed in the hope that it will be useful,
15 # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
17 # GNU General Public License for more details.
18 #
19 # You should have received a copy of the GNU General Public License along
20 # with this program; if not, write to the Free Software Foundation, Inc.,
21 # 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
22
23 import argparse
24 import collections
25 import logging
26 import gzip
27 import os
28 import sys
29 import time
30
31 logger = logging.getLogger('extractor')
32
33 def main(cmdline=None):
34     parser = make_parser()
35     args = parser.parse_args(cmdline)
36
37     if args.verbose:
38         logging.basicConfig(level=logging.INFO)
39
40     geneid_map = build_geneid_to_gene(args.gtf) if args.gtf else {}
41
42     if not args.quantifications:
43         parser.error("Please list files to extract quantifications from")
44         
45     output_headers, matrix = load_matrixes(geneid_map,
46                                            args.quantifications,
47                                            args.column)
48     write_merged_matrix(args.output, output_headers, matrix, args.no_zeros)
49
50
51 def load_matrixes(geneid_map, quantifications, column_name):
52     """Load a quantification from a list of quantification files.
53
54     This will also convert through a gene id to gene_name map.
55     if a gene name isn't found, it will default to the gene id.
56
57     Arguments:
58         geneid_map (dict): mapping between gene ids and gene names
59         quantifications (list): list of filenames to load from
60
61     Returns:
62         output_headers (list): list of column headers for matrix
63             (derived from input filenames)
64         matrix (dict of lists): selected quantification values
65             by gene.
66     """
67     matrix = collections.OrderedDict()
68     output_headers = ['#genes']
69     start = time.time()
70     for quantification in quantifications:
71         logger.info("Loading %s", quantification)
72         with open(quantification, 'rt') as instream:
73             output_headers.append(os.path.basename(quantification))
74             headers = instream.readline().split('\t')
75             try:
76                 column_to_use = headers.index(column_name)
77             except ValueError as e:
78                 raise RuntimeError(
79                     'Error: {} is not one of the column headers {}'.format(
80                     args.column, headers))
81             
82             for line in instream:
83                 columns = line.split('\t')
84                 key = geneid_map.get(columns[0], columns[0])
85                 matrix.setdefault(key, []).append(columns[column_to_use])
86
87     logger.info("Loaded %d matrixes in %d seconds",
88                 len(quantifications),
89                 time.time() - start)
90     return output_headers, matrix
91
92
93 def write_merged_matrix(output, headers, matrix, drop_zeros=False):
94     """Save matrix
95
96     Arguments:
97         output (str): output filename or None for stdout
98         headers (list): list of matrix column headers)
99         matrix (dict): gene_name: list of interested
100         drop_zeros (bool): should we drop rows that are all zero?
101     """
102     logger.info("Writing matrix")
103     if output:
104         outstream = open(output, 'wt')
105     else:
106         outstream = sys.stdout
107
108     outstream.write('\t'.join(headers))
109     outstream.write(os.linesep)
110     for key in matrix:
111         columns = matrix[key]
112         
113         # skip over zero rows
114         if drop_zeros:
115             for x in columns:
116                 if float(x) != 0:
117                     break
118             else:
119                 continue
120             
121         outstream.write(key)
122         outstream.write('\t')
123         outstream.write('\t'.join(matrix[key]))
124         outstream.write(os.linesep)
125
126     if outstream != sys.stdout:
127         outstream.close()
128         
129
130 def make_parser():
131     """Build argument parser.
132     """
133     parser = argparse.ArgumentParser()
134     parser.add_argument('--gtf', help='gtf file to load')
135     parser.add_argument('--column', default='FPKM',
136                         help='which column to use')
137     parser.add_argument('-o', '--output',
138                         help='filename to write merged matrix to')
139     parser.add_argument('--no-zeros', default=False, action='store_true',
140                         help='Drop rows that are all zero')
141     parser.add_argument('-v', '--verbose', default=False,
142                         action='store_true',
143                         help='report progress')
144     parser.add_argument('quantifications', nargs='*',
145                         help='list of quantification files to load')
146     
147     return parser
148
149
150 def build_geneid_to_gene(gencode):
151     """Build a dictionary mapping from gene_id to gene_name.
152
153     Arguments:
154         gencode (str): compressed filename to read
155
156     Returns:
157         dictionary mapping gene_id to gene_name
158     """
159     logger.info("Loading %s", gencode)
160     start = time.time()
161     names = {}
162     with gzip.GzipFile(gencode, 'r') as instream:
163         for line in instream:
164             line = line.decode('ascii')
165             
166             if line.startswith('#'):
167                 continue
168             
169             columns = line.split('\t')
170
171             gene_id = None
172             gene_name = None
173             
174             for item in columns[-1].split(';'):
175                 item = item.strip()
176                 if len(item) == 0:
177                     continue
178                 item = item.split()
179                 if len(item) != 2:
180                     print("Confused: {} {}".format(item, len(item)))
181                 name, value = item
182                 if name == 'gene_id':
183                     gene_id = value[1:-1]
184                 elif name == 'gene_name':
185                     gene_name = value[1:-1]
186
187             if gene_id and gene_name:
188                 names[gene_id] = gene_name
189
190     logger.info("loaded in %d seconds", time.time() - start)
191     return names
192                         
193
194 if __name__ == '__main__':
195     main()
196