Ignore emacs backup files of tsvs in addition to python files
[helpful_scripts.git] / translate_tsv_genes.py
1 #!/usr/bin/env python3
2 """Generate a quantification matrix 
3
4 This is intended to extract one quantification column
5 from each of a set of gene quantification files.
6 """
7 # Copyright (2015) Diane Trout & California Institute of Technology
8
9 # This program is free software; you can redistribute it and/or modify
10 # it under the terms of the GNU General Public License as published by
11 # the Free Software Foundation; either version 2 of the License, or
12 # (at your option) any later version.
13 #
14 # This program is distributed in the hope that it will be useful,
15 # but WITHOUT ANY WARRANTY; without even the implied warranty of
16 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
17 # GNU General Public License for more details.
18 #
19 # You should have received a copy of the GNU General Public License along
20 # with this program; if not, write to the Free Software Foundation, Inc.,
21 # 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
22
23 import argparse
24 import collections
25 import logging
26 import gzip
27 import os
28 from os.path import basename
29 import sys
30 import time
31 import pandas
32
33 logger = logging.getLogger('extractor')
34
35 def main(cmdline=None):
36     parser = make_parser()
37     args = parser.parse_args(cmdline)
38
39     if args.verbose:
40         logging.basicConfig(level=logging.INFO)
41
42     geneid_map = build_geneid_to_gene(args.gtf) if args.gtf else {}
43
44     if not args.quantifications:
45         parser.error("Please list files to extract quantifications from")
46
47     matrix = load_matrixes(args.quantifications, args.column)
48
49     matrix = filter_matrix(matrix, args.no_zeros)
50
51     if args.output:
52         outstream = open(args.output, 'wt')
53     else:
54         outstream = sys.stdout
55
56     write_merged_matrix(outstream, geneid_map, matrix)
57
58     if args.output:
59         outstream.close()
60
61
62 def load_matrixes(quantifications, column_name):
63     """Load a quantification from a list of quantification files.
64
65     This will also convert through a gene id to gene_name map.
66     if a gene name isn't found, it will default to the gene id.
67
68     Arguments:
69         quantifications (list): list of filenames to load from
70         column_name (str): what column we should be looking for
71
72     Returns:
73         matrix (pandas.DataFrame): selected quantification values
74             by gene.
75     """
76     columns = collections.OrderedDict()
77     start = time.time()
78     for quantification in quantifications:
79         name = basename(getattr(quantification, 'name', quantification))
80         logger.info("Loading %s", name)
81         table = pandas.read_csv(quantification, index_col=0, sep='\t')
82         columns[name] = table[column_name]
83
84     matrix = pandas.DataFrame(columns, columns=columns.keys())
85
86     logger.info("Loaded %d matrixes in %d seconds",
87                 len(quantifications),
88                 time.time() - start)
89     return matrix
90
91
92 def filter_matrix(matrix, drop_zeros):
93     """apply transformations to matrix
94
95     Should we drop rows with all zero or NaN?
96
97     Arguments:
98       matrix (pandas.DataFrame): source matrix
99       drop_zeros (bool): should we drop rows that are all zero?
100     """
101     if drop_zeros:
102         matrix = matrix[matrix > 0].dropna(how='all')
103
104     return matrix.fillna(0)
105
106
107 def write_merged_matrix(outstream, geneid_map, matrix,
108                         drop_zeros=False):
109     """Save matrix
110
111     Arguments:
112         outstream (stream): output to write to
113         geneid_map (dict): gene id to gene name mapping
114         matrix (dict): gene_name: list of interested
115     """
116     logger.info("Writing matrix")
117
118     headers = ['#gene_id']
119     headers.extend(matrix.keys())
120     outstream.write('\t'.join(headers))
121     outstream.write(os.linesep)
122
123     for index in matrix.index:
124         label = []
125         gene_name = geneid_map.get(index, None)
126         if gene_name:
127             label.append(gene_name)
128         label.append(index)
129
130         outstream.write('-'.join(label))
131         outstream.write('\t')
132         outstream.write('\t'.join((str(x) for x in matrix.ix[index])))
133         outstream.write(os.linesep)
134
135
136 def make_parser():
137     """Build argument parser.
138     """
139     parser = argparse.ArgumentParser()
140     parser.add_argument('--gtf', help='gtf file to load')
141     parser.add_argument('--column', default='FPKM',
142                         help='which column to use')
143     parser.add_argument('-o', '--output',
144                         help='filename to write merged matrix to')
145     parser.add_argument('--no-zeros', default=False, action='store_true',
146                         help='Drop rows that are all zero')
147     parser.add_argument('-v', '--verbose', default=False,
148                         action='store_true',
149                         help='report progress')
150     parser.add_argument('quantifications', nargs='*',
151                         help='list of quantification files to load')
152     
153     return parser
154
155
156 def build_geneid_to_gene(gencode):
157     """Build a dictionary mapping from gene_id to gene_name.
158
159     Arguments:
160         gencode (str): compressed filename to read
161
162     Returns:
163         dictionary mapping gene_id to gene_name
164     """
165     logger.info("Loading %s", gencode)
166     start = time.time()
167     names = {}
168     with gzip.GzipFile(gencode, 'r') as instream:
169         for line in instream:
170             line = line.decode('ascii')
171             
172             if line.startswith('#'):
173                 continue
174             
175             columns = line.split('\t')
176
177             gene_id = None
178             gene_name = None
179             
180             for item in columns[-1].split(';'):
181                 item = item.strip()
182                 if len(item) == 0:
183                     continue
184                 item = item.split()
185                 if len(item) != 2:
186                     print("Confused: {} {}".format(item, len(item)))
187                 name, value = item
188                 if name == 'gene_id':
189                     gene_id = value[1:-1]
190                 elif name == 'gene_name':
191                     gene_name = value[1:-1]
192
193             if gene_id and gene_name:
194                 names[gene_id] = gene_name
195
196     logger.info("loaded in %d seconds", time.time() - start)
197     return names
198                         
199
200 if __name__ == '__main__':
201     main()
202