Use pandas to handle merging our table.
authorDiane Trout <diane@ghic.org>
Mon, 4 May 2015 22:24:21 +0000 (15:24 -0700)
committerDiane Trout <diane@ghic.org>
Mon, 4 May 2015 22:24:21 +0000 (15:24 -0700)
My dictionary implementation had trouble when there were
different ids in different files.
(It'd end up with lists that were to short)

translate_tsv_genes.py

index 8b9930a3975f0b167d27c70a1a347a5363e96311..68d46b0a34548e4b5175a983066d6a3d85edde6e 100644 (file)
@@ -25,8 +25,10 @@ import collections
 import logging
 import gzip
 import os
 import logging
 import gzip
 import os
+from os.path import basename
 import sys
 import time
 import sys
 import time
+import pandas
 
 logger = logging.getLogger('extractor')
 
 
 logger = logging.getLogger('extractor')
 
@@ -41,23 +43,22 @@ def main(cmdline=None):
 
     if not args.quantifications:
         parser.error("Please list files to extract quantifications from")
 
     if not args.quantifications:
         parser.error("Please list files to extract quantifications from")
-        
-    output_headers, matrix = load_matrixes(args.quantifications,
-                                           args.column)
+
+    matrix = load_matrixes(args.quantifications, args.column)
+
+    matrix = filter_matrix(matrix, args.no_zeros)
+
     if args.output:
         outstream = open(args.output, 'wt')
     else:
         outstream = sys.stdout
 
     if args.output:
         outstream = open(args.output, 'wt')
     else:
         outstream = sys.stdout
 
-    write_merged_matrix(outstream,
-                        geneid_map,
-                        output_headers,
-                        matrix,
-                        args.no_zeros)
+    write_merged_matrix(outstream, geneid_map, matrix)
 
     if args.output:
         outstream.close()
 
 
     if args.output:
         outstream.close()
 
+
 def load_matrixes(quantifications, column_name):
     """Load a quantification from a list of quantification files.
 
 def load_matrixes(quantifications, column_name):
     """Load a quantification from a list of quantification files.
 
@@ -69,74 +70,68 @@ def load_matrixes(quantifications, column_name):
         column_name (str): what column we should be looking for
 
     Returns:
         column_name (str): what column we should be looking for
 
     Returns:
-        output_headers (list): list of column headers for matrix
-            (derived from input filenames)
-        matrix (dict of lists): selected quantification values
+        matrix (pandas.DataFrame): selected quantification values
             by gene.
     """
             by gene.
     """
-    matrix = collections.OrderedDict()
-    output_headers = ['#genes']
+    columns = collections.OrderedDict()
     start = time.time()
     for quantification in quantifications:
     start = time.time()
     for quantification in quantifications:
-        logger.info("Loading %s", quantification)
-        with open(quantification, 'rt') as instream:
-            output_headers.append(os.path.basename(quantification))
-            headers = instream.readline().split('\t')
-            try:
-                column_to_use = headers.index(column_name)
-            except ValueError as e:
-                raise RuntimeError(
-                    'Error: {} is not one of the column headers {}'.format(
-                    args.column, headers))
-            
-            for line in instream:
-                columns = line.split('\t')
-                key = columns[0]
-                matrix.setdefault(key, []).append(columns[column_to_use])
+        name = basename(getattr(quantification, 'name', quantification))
+        logger.info("Loading %s", name)
+        table = pandas.read_csv(quantification, index_col=0, sep='\t')
+        columns[name] = table[column_name]
+
+    matrix = pandas.DataFrame(columns, columns=columns.keys())
 
     logger.info("Loaded %d matrixes in %d seconds",
                 len(quantifications),
                 time.time() - start)
 
     logger.info("Loaded %d matrixes in %d seconds",
                 len(quantifications),
                 time.time() - start)
-    return output_headers, matrix
+    return matrix
+
+
+def filter_matrix(matrix, drop_zeros):
+    """apply transformations to matrix
+
+    Should we drop rows with all zero or NaN?
 
 
+    Arguments:
+      matrix (pandas.DataFrame): source matrix
+      drop_zeros (bool): should we drop rows that are all zero?
+    """
+    if drop_zeros:
+        matrix = matrix[matrix > 0].dropna(how='all')
+
+    return matrix.fillna(0)
 
 
-def write_merged_matrix(outstream, geneid_map, headers, matrix,
+
+def write_merged_matrix(outstream, geneid_map, matrix,
                         drop_zeros=False):
     """Save matrix
 
     Arguments:
         outstream (stream): output to write to
         geneid_map (dict): gene id to gene name mapping
                         drop_zeros=False):
     """Save matrix
 
     Arguments:
         outstream (stream): output to write to
         geneid_map (dict): gene id to gene name mapping
-        headers (list): list of matrix column headers)
         matrix (dict): gene_name: list of interested
         matrix (dict): gene_name: list of interested
-        drop_zeros (bool): should we drop rows that are all zero?
     """
     logger.info("Writing matrix")
 
     """
     logger.info("Writing matrix")
 
+    headers = ['#gene_id']
+    headers.extend(matrix.keys())
     outstream.write('\t'.join(headers))
     outstream.write(os.linesep)
     outstream.write('\t'.join(headers))
     outstream.write(os.linesep)
-    for key in matrix:
-        columns = matrix[key]
-        
-        # skip over zero rows
-        if drop_zeros:
-            for x in columns:
-                if float(x) != 0:
-                    break
-            else:
-                continue
 
 
+    for index in matrix.index:
         label = []
         label = []
-        gene_name = geneid_map.get(key, None)
+        gene_name = geneid_map.get(index, None)
         if gene_name:
             label.append(gene_name)
         if gene_name:
             label.append(gene_name)
-        label.append(key)
-            
+        label.append(index)
+
         outstream.write('-'.join(label))
         outstream.write('\t')
         outstream.write('-'.join(label))
         outstream.write('\t')
-        outstream.write('\t'.join(matrix[key]))
+        outstream.write('\t'.join((str(x) for x in matrix.ix[index])))
         outstream.write(os.linesep)
         outstream.write(os.linesep)
-        
+
 
 def make_parser():
     """Build argument parser.
 
 def make_parser():
     """Build argument parser.