X-Git-Url: http://woldlab.caltech.edu/gitweb/?p=helpful_scripts.git;a=blobdiff_plain;f=translate_tsv_genes.py;h=68d46b0a34548e4b5175a983066d6a3d85edde6e;hp=a302d654c6450e9ca4056d54e8b0a31670fc06b0;hb=HEAD;hpb=350df060402bb45ba83f552430043dab853a22ec diff --git a/translate_tsv_genes.py b/translate_tsv_genes.py index a302d65..68d46b0 100644 --- a/translate_tsv_genes.py +++ b/translate_tsv_genes.py @@ -25,8 +25,10 @@ import collections import logging import gzip import os +from os.path import basename import sys import time +import pandas logger = logging.getLogger('extractor') @@ -41,91 +43,95 @@ def main(cmdline=None): if not args.quantifications: parser.error("Please list files to extract quantifications from") - - output_headers, matrix = load_matrixes(geneid_map, - args.quantifications, - args.column) - write_merged_matrix(args.output, output_headers, matrix, args.no_zeros) + + matrix = load_matrixes(args.quantifications, args.column) + + matrix = filter_matrix(matrix, args.no_zeros) + + if args.output: + outstream = open(args.output, 'wt') + else: + outstream = sys.stdout + + write_merged_matrix(outstream, geneid_map, matrix) + + if args.output: + outstream.close() -def load_matrixes(geneid_map, quantifications, column_name): +def load_matrixes(quantifications, column_name): """Load a quantification from a list of quantification files. This will also convert through a gene id to gene_name map. if a gene name isn't found, it will default to the gene id. Arguments: - geneid_map (dict): mapping between gene ids and gene names quantifications (list): list of filenames to load from + column_name (str): what column we should be looking for Returns: - output_headers (list): list of column headers for matrix - (derived from input filenames) - matrix (dict of lists): selected quantification values + matrix (pandas.DataFrame): selected quantification values by gene. """ - matrix = collections.OrderedDict() - output_headers = ['#genes'] + columns = collections.OrderedDict() start = time.time() for quantification in quantifications: - logger.info("Loading %s", quantification) - with open(quantification, 'rt') as instream: - output_headers.append(os.path.basename(quantification)) - headers = instream.readline().split('\t') - try: - column_to_use = headers.index(column_name) - except ValueError as e: - raise RuntimeError( - 'Error: {} is not one of the column headers {}'.format( - args.column, headers)) - - for line in instream: - columns = line.split('\t') - key = geneid_map.get(columns[0], columns[0]) - matrix.setdefault(key, []).append(columns[column_to_use]) + name = basename(getattr(quantification, 'name', quantification)) + logger.info("Loading %s", name) + table = pandas.read_csv(quantification, index_col=0, sep='\t') + columns[name] = table[column_name] + + matrix = pandas.DataFrame(columns, columns=columns.keys()) logger.info("Loaded %d matrixes in %d seconds", len(quantifications), time.time() - start) - return output_headers, matrix + return matrix -def write_merged_matrix(output, headers, matrix, drop_zeros=False): +def filter_matrix(matrix, drop_zeros): + """apply transformations to matrix + + Should we drop rows with all zero or NaN? + + Arguments: + matrix (pandas.DataFrame): source matrix + drop_zeros (bool): should we drop rows that are all zero? + """ + if drop_zeros: + matrix = matrix[matrix > 0].dropna(how='all') + + return matrix.fillna(0) + + +def write_merged_matrix(outstream, geneid_map, matrix, + drop_zeros=False): """Save matrix Arguments: - output (str): output filename or None for stdout - headers (list): list of matrix column headers) + outstream (stream): output to write to + geneid_map (dict): gene id to gene name mapping matrix (dict): gene_name: list of interested - drop_zeros (bool): should we drop rows that are all zero? """ logger.info("Writing matrix") - if output: - outstream = open(output, 'wt') - else: - outstream = sys.stdout + headers = ['#gene_id'] + headers.extend(matrix.keys()) outstream.write('\t'.join(headers)) outstream.write(os.linesep) - for key in matrix: - columns = matrix[key] - - # skip over zero rows - if drop_zeros: - for x in columns: - if float(x) != 0: - break - else: - continue - - outstream.write(key) + + for index in matrix.index: + label = [] + gene_name = geneid_map.get(index, None) + if gene_name: + label.append(gene_name) + label.append(index) + + outstream.write('-'.join(label)) outstream.write('\t') - outstream.write('\t'.join(matrix[key])) + outstream.write('\t'.join((str(x) for x in matrix.ix[index]))) outstream.write(os.linesep) - if outstream != sys.stdout: - outstream.close() - def make_parser(): """Build argument parser.