From: Diane Trout Date: Mon, 4 May 2015 22:24:21 +0000 (-0700) Subject: Use pandas to handle merging our table. X-Git-Url: http://woldlab.caltech.edu/gitweb/?p=helpful_scripts.git;a=commitdiff_plain;h=216f33a91afbb7263476801ec06f636c8c78d702 Use pandas to handle merging our table. My dictionary implementation had trouble when there were different ids in different files. (It'd end up with lists that were too short) --- diff --git a/translate_tsv_genes.py b/translate_tsv_genes.py index 8b9930a..68d46b0 100644 --- a/translate_tsv_genes.py +++ b/translate_tsv_genes.py @@ -25,8 +25,10 @@ import collections import logging import gzip import os +from os.path import basename import sys import time +import pandas logger = logging.getLogger('extractor') @@ -41,23 +43,22 @@ def main(cmdline=None): if not args.quantifications: parser.error("Please list files to extract quantifications from") - - output_headers, matrix = load_matrixes(args.quantifications, - args.column) + + matrix = load_matrixes(args.quantifications, args.column) + + matrix = filter_matrix(matrix, args.no_zeros) + if args.output: outstream = open(args.output, 'wt') else: outstream = sys.stdout - write_merged_matrix(outstream, - geneid_map, - output_headers, - matrix, - args.no_zeros) + write_merged_matrix(outstream, geneid_map, matrix) if args.output: outstream.close() + def load_matrixes(quantifications, column_name): """Load a quantification from a list of quantification files. @@ -69,74 +70,68 @@ def load_matrixes(quantifications, column_name): column_name (str): what column we should be looking for Returns: - output_headers (list): list of column headers for matrix - (derived from input filenames) - matrix (dict of lists): selected quantification values + matrix (pandas.DataFrame): selected quantification values by gene. 
""" - matrix = collections.OrderedDict() - output_headers = ['#genes'] + columns = collections.OrderedDict() start = time.time() for quantification in quantifications: - logger.info("Loading %s", quantification) - with open(quantification, 'rt') as instream: - output_headers.append(os.path.basename(quantification)) - headers = instream.readline().split('\t') - try: - column_to_use = headers.index(column_name) - except ValueError as e: - raise RuntimeError( - 'Error: {} is not one of the column headers {}'.format( - args.column, headers)) - - for line in instream: - columns = line.split('\t') - key = columns[0] - matrix.setdefault(key, []).append(columns[column_to_use]) + name = basename(getattr(quantification, 'name', quantification)) + logger.info("Loading %s", name) + table = pandas.read_csv(quantification, index_col=0, sep='\t') + columns[name] = table[column_name] + + matrix = pandas.DataFrame(columns, columns=columns.keys()) logger.info("Loaded %d matrixes in %d seconds", len(quantifications), time.time() - start) - return output_headers, matrix + return matrix + + +def filter_matrix(matrix, drop_zeros): + """apply transformations to matrix + + Should we drop rows with all zero or NaN? + Arguments: + matrix (pandas.DataFrame): source matrix + drop_zeros (bool): should we drop rows that are all zero? + """ + if drop_zeros: + matrix = matrix[matrix > 0].dropna(how='all') + + return matrix.fillna(0) -def write_merged_matrix(outstream, geneid_map, headers, matrix, + +def write_merged_matrix(outstream, geneid_map, matrix, drop_zeros=False): """Save matrix Arguments: outstream (stream): output to write to geneid_map (dict): gene id to gene name mapping - headers (list): list of matrix column headers) matrix (dict): gene_name: list of interested - drop_zeros (bool): should we drop rows that are all zero? 
""" logger.info("Writing matrix") + headers = ['#gene_id'] + headers.extend(matrix.keys()) outstream.write('\t'.join(headers)) outstream.write(os.linesep) - for key in matrix: - columns = matrix[key] - - # skip over zero rows - if drop_zeros: - for x in columns: - if float(x) != 0: - break - else: - continue + for index in matrix.index: label = [] - gene_name = geneid_map.get(key, None) + gene_name = geneid_map.get(index, None) if gene_name: label.append(gene_name) - label.append(key) - + label.append(index) + outstream.write('-'.join(label)) outstream.write('\t') - outstream.write('\t'.join(matrix[key])) + outstream.write('\t'.join((str(x) for x in matrix.ix[index]))) outstream.write(os.linesep) - + def make_parser(): """Build argument parser.