write matrix subsampler
[helpful_scripts.git] / translate_tsv_genes.py
index 328c8c36491b961c248a95b82056136e27ecb807..8b9930a3975f0b167d27c70a1a347a5363e96311 100644 (file)
@@ -4,6 +4,22 @@
 This is intended to extract one quantification column
 from each of a set of gene quantification files.
 """
+# Copyright (2015) Diane Trout & California Institute of Technology
+
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation; either version 2 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License along
+# with this program; if not, write to the Free Software Foundation, Inc.,
+# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+
 import argparse
 import collections
 import logging
@@ -26,21 +42,31 @@ def main(cmdline=None):
     if not args.quantifications:
         parser.error("Please list files to extract quantifications from")
         
-    output_headers, matrix = load_matrixes(geneid_map,
-                                           args.quantifications,
+    output_headers, matrix = load_matrixes(args.quantifications,
                                            args.column)
-    write_merged_matrix(args.output, output_headers, matrix, args.no_zeros)
+    if args.output:
+        outstream = open(args.output, 'wt')
+    else:
+        outstream = sys.stdout
 
+    write_merged_matrix(outstream,
+                        geneid_map,
+                        output_headers,
+                        matrix,
+                        args.no_zeros)
+
+    if args.output:
+        outstream.close()
 
-def load_matrixes(geneid_map, quantifications, column_name):
+def load_matrixes(quantifications, column_name):
     """Load a quantification from a list of quantification files.
 
     This will also convert through a gene id to gene_name map.
     if a gene name isn't found, it will default to the gene id.
 
     Arguments:
-        geneid_map (dict): mapping between gene ids and gene names
         quantifications (list): list of filenames to load from
+        column_name (str): what column we should be looking for
 
     Returns:
         output_headers (list): list of column headers for matrix
@@ -65,7 +91,7 @@ def load_matrixes(geneid_map, quantifications, column_name):
             
             for line in instream:
                 columns = line.split('\t')
-                key = geneid_map.get(columns[0], columns[0])
+                key = columns[0]
                 matrix.setdefault(key, []).append(columns[column_to_use])
 
     logger.info("Loaded %d matrixes in %d seconds",
@@ -74,20 +100,18 @@ def load_matrixes(geneid_map, quantifications, column_name):
     return output_headers, matrix
 
 
-def write_merged_matrix(output, headers, matrix, drop_zeros=False):
+def write_merged_matrix(outstream, geneid_map, headers, matrix,
+                        drop_zeros=False):
     """Save matrix
 
     Arguments:
-        output (str): output filename or None for stdout
+        outstream (stream): output to write to
+        geneid_map (dict): gene id to gene name mapping
         headers (list): list of matrix column headers)
         matrix (dict): gene_name: list of interested
         drop_zeros (bool): should we drop rows that are all zero?
     """
     logger.info("Writing matrix")
-    if output:
-        outstream = open(output, 'wt')
-    else:
-        outstream = sys.stdout
 
     outstream.write('\t'.join(headers))
     outstream.write(os.linesep)
@@ -101,14 +125,17 @@ def write_merged_matrix(output, headers, matrix, drop_zeros=False):
                     break
             else:
                 continue
+
+        label = []
+        gene_name = geneid_map.get(key, None)
+        if gene_name:
+            label.append(gene_name)
+        label.append(key)
             
-        outstream.write(key)
+        outstream.write('-'.join(label))
         outstream.write('\t')
         outstream.write('\t'.join(matrix[key]))
         outstream.write(os.linesep)
-
-    if outstream != sys.stdout:
-        outstream.close()
         
 
 def make_parser():