initial (mostly working) version
[helpful_scripts.git] / extractor.py
1 #!/usr/bin/env python3
2 """Generate a quantification matrix 
3
4 This is intended to extract one quantification column
5 from each of a set of gene quantification files.
6 """
7 import argparse
8 import collections
9 import logging
10 import gzip
11 import os
12 import sys
13 import time
14
15 logger = logging.getLogger('extractor')
16
17 def main(cmdline=None):
18     parser = make_parser()
19     args = parser.parse_args(cmdline)
20
21     if args.verbose:
22         logging.basicConfig(level=logging.INFO)
23
24     geneid_map = build_geneid_to_gene(args.gtf) if args.gtf else {}
25
26     if not args.quantifications:
27         parser.error("Please list files to extract quantifications from")
28         
29     output_headers, matrix = load_matrixes(geneid_map,
30                                            args.quantifications,
31                                            args.column)
32     write_merged_matrix(args.output, output_headers, matrix, args.no_zeros)
33
34
35 def load_matrixes(geneid_map, quantifications, column_name):
36     """Load a quantification from a list of quantification files.
37
38     This will also convert through a gene id to gene_name map.
39     if a gene name isn't found, it will default to the gene id.
40
41     Arguments:
42         geneid_map (dict): mapping between gene ids and gene names
43         quantifications (list): list of filenames to load from
44
45     Returns:
46         output_headers (list): list of column headers for matrix
47             (derived from input filenames)
48         matrix (dict of lists): selected quantification values
49             by gene.
50     """
51     matrix = collections.OrderedDict()
52     output_headers = ['#genes']
53     start = time.time()
54     for quantification in quantifications:
55         logger.info("Loading %s", quantification)
56         with open(quantification, 'rt') as instream:
57             output_headers.append(os.path.basename(quantification))
58             headers = instream.readline().split('\t')
59             try:
60                 column_to_use = headers.index(column_name)
61             except ValueError as e:
62                 raise RuntimeError(
63                     'Error: {} is not one of the column headers {}'.format(
64                     args.column, headers))
65             
66             for line in instream:
67                 columns = line.split('\t')
68                 key = geneid_map.get(columns[0], columns[0])
69                 matrix.setdefault(key, []).append(columns[column_to_use])
70
71     logger.info("Loaded %d matrixes in %d seconds",
72                 len(quantifications),
73                 time.time() - start)
74     return output_headers, matrix
75
76
77 def write_merged_matrix(output, headers, matrix, drop_zeros=False):
78     """Save matrix
79
80     Arguments:
81         output (str): output filename or None for stdout
82         headers (list): list of matrix column headers)
83         matrix (dict): gene_name: list of interested
84         drop_zeros (bool): should we drop rows that are all zero?
85     """
86     logger.info("Writing matrix")
87     if output:
88         outstream = open(output, 'wt')
89     else:
90         outstream = sys.stdout
91
92     outstream.write('\t'.join(headers))
93     outstream.write(os.linesep)
94     for key in matrix:
95         columns = matrix[key]
96         
97         # skip over zero rows
98         if drop_zeros:
99             for x in columns:
100                 if float(x) != 0:
101                     break
102             else:
103                 continue
104             
105         outstream.write(key)
106         outstream.write('\t')
107         outstream.write('\t'.join(matrix[key]))
108         outstream.write(os.linesep)
109
110     if outstream != sys.stdout:
111         outstream.close()
112         
113
114 def make_parser():
115     """Build argument parser.
116     """
117     parser = argparse.ArgumentParser()
118     parser.add_argument('--gtf', help='gtf file to load')
119     parser.add_argument('--column', default='FPKM',
120                         help='which column to use')
121     parser.add_argument('-o', '--output',
122                         help='filename to write merged matrix to')
123     parser.add_argument('--no-zeros', default=False, action='store_true',
124                         help='Drop rows that are all zero')
125     parser.add_argument('-v', '--verbose', default=False,
126                         action='store_true',
127                         help='report progress')
128     parser.add_argument('quantifications', nargs='*',
129                         help='list of quantification files to load')
130     
131     return parser
132
133
134 def build_geneid_to_gene(gencode):
135     """Build a dictionary mapping from gene_id to gene_name.
136
137     Arguments:
138         gencode (str): compressed filename to read
139
140     Returns:
141         dictionary mapping gene_id to gene_name
142     """
143     logger.info("Loading %s", gencode)
144     start = time.time()
145     names = {}
146     with gzip.GzipFile(gencode, 'r') as instream:
147         for line in instream:
148             line = line.decode('ascii')
149             
150             if line.startswith('#'):
151                 continue
152             
153             columns = line.split('\t')
154
155             gene_id = None
156             gene_name = None
157             
158             for item in columns[-1].split(';'):
159                 item = item.strip()
160                 if len(item) == 0:
161                     continue
162                 item = item.split()
163                 if len(item) != 2:
164                     print("Confused: {} {}".format(item, len(item)))
165                 name, value = item
166                 if name == 'gene_id':
167                     gene_id = value[1:-1]
168                 elif name == 'gene_name':
169                     gene_name = value[1:-1]
170
171             if gene_id and gene_name:
172                 names[gene_id] = gene_name
173
174     logger.info("loaded in %d seconds", time.time() - start)
175     return names
176                         
177
178 if __name__ == '__main__':
179     main()
180