Bugfixes
[htsworkflow.git] / htswanalysis / src / GetReadsInSnps / getRegionConsensus.cpp
1 /*
2  * getReadsInSnps:
3  * This program takes a set of snps in a custom tab format, and a set of short mapped reads, and evaluates
4  * the sequencing overlap over those snps. Additionally, a miaxture model is fit and used to classify the 
5  * snps as homozygous or heterozygous.
6  * 
7  * In the final report, the output is:
8  * <snp id> <chromosome> <position> <reference base> <a count> <c count> <g count> <t count> <total count> <snp call>
9  * where snp call is one of:
10  * -1: no call was made (not enough examples to make a call)
11  * 0: the snp is homozygous
12  * 1: the snp is heterozygous
13  *
14  */
15 #include <sys/types.h>
16 #include <iostream>
17 #include <fstream>
18 #include <vector>
19 #include <map>
20 #include <queue>
21 #include <math.h>
22 #include <string>
23 #include <limits.h>
24
25 #include <gsl/gsl_statistics.h>
26
27 #include "../SRLib/chrom_list.h"
28 #include "../SRLib/util.h"
29 #include "../SRLib/loci.h"
30 #include "../SRLib/read.h"
31 #include "../SRLib/nuc.h"
32
33 #define WINDOW 25
34 #define PI 3.14159265358979323846
35
36 #define DEBUG
37
38 #ifdef DEBUG
39 //#include "duma.h"
40 #endif
41
42 using namespace std;
43
44 class Window : public Loci {
45   public:
46     //optional name for the window
47     string name;
48
49     //the consensus sequence
50     string sequence;
51     unsigned int length;
52     vector<Nuc> seq;
53
54     unsigned int reads;
55
56     Window(string name, string chr, unsigned int pos, unsigned int length) : Loci(chr,pos) { 
57       this->name = name;
58       this->length = length; 
59       if(length > 10000) { cerr << "ERROR: window of size " << length << endl; exit(1); }
60
61       this->sequence = string(length,'\0');
62       seq.resize(this->sequence.length());
63       this->reads = 0;
64     }
65
66     ~Window() {
67       seq.clear();
68     }
69
70     Window(const Window& r) : Loci(r) { 
71       this->name = r.name;
72       this->length = r.length; 
73       this->seq = r.seq;
74       this->sequence = r.sequence;
75       this->reads = r.reads;
76     }
77
78     Window& operator=(const Window& r) { 
79      if (this == &r)  return *this;
80       Loci::operator=(r);
81       this->name = r.name;
82       this->length = r.length; 
83       this->sequence = r.sequence;
84       this->seq = r.seq;
85       this->reads = r.reads;
86       return *this;
87     }
88
89     void set_sequence(string s) {
90       this->sequence = s;
91       unsigned int a;
92       //clear out endlines
93       while( (a = (sequence.find("\n"))) != string::npos) { sequence.erase(a,1); }
94       seq.clear(); seq.resize(sequence.length());
95     }
96    
97    string get_sequence() {
98      return this->sequence;
99    }
100
101     void add_read(const Read& r) {
102       if(this->chr != r.chr) return;
103
104       int w_offset = 0;
105       int r_offset = 0;
106       int overlap_len = r.length();
107
108       //if the read begins before the window starts:
109       //  ---(====[====)------]------
110       if(r.pos < this->pos) {
111         r_offset = this->pos - r.pos;
112         w_offset = 0;
113       } else {
114         w_offset = r.pos - this->pos;
115         r_offset = 0;
116       }
117
118       //if the read ends after the window ends 
119       //  ----[----(====]======)-----
120       if(r.pos + r.length() >= this->pos + this->length) {
121         overlap_len = (this->pos + this->length) - r.pos;  
122       } else {
123         overlap_len = r.length();
124       }
125
126       for(; overlap_len > 0; overlap_len--) {
127         seq[w_offset++].add_nuc(r[r_offset++]);
128       }
129
130       this->reads++;
131     }
132
133     void print_consensus(ostream& o) {
134       unsigned int line_len = 100; 
135       o << ">Consensus for: " << name << " (" << this->chr << ":" << this->pos << "-" << this->pos+this->length << ")" << endl;
136
137       for(unsigned int offset = 0; offset < sequence.length(); offset += line_len) {
138         unsigned int max_len = sequence.length() - offset;
139         unsigned int len = (line_len > max_len)?max_len:line_len;
140         o << sequence.substr(offset,len) << endl;
141         for(unsigned int i = offset; i < offset+len; i++) { 
142           char ref = toupper(sequence[i]);
143           char con = toupper(seq[i].consensus());
144           if(con == ' ') {
145             o << ' ';
146           } else if(con == ref) {
147             o << '|';
148           } else {
149             o << '*';
150           }
151         }
152         o << endl;
153         for(unsigned int i = offset; i < offset+len; i++) { o << seq[i].consensus(); }
154         o << endl << endl;
155       }
156     }
157
158     void print_fasta(ostream& o) {
159       unsigned int line_len = 100; 
160         
161       if(sequence.length() != seq.size()) {
162         cerr << "Size mismatch: sequence is " << sequence.length() << " but data has " << seq.size() << endl;
163       }
164
165       string output = "";
166       vector<string> variants;
167
168       for(unsigned int offset = 0; offset < sequence.length(); offset += line_len) {
169         unsigned int max_len = sequence.length() - offset;
170         unsigned int len = (line_len > max_len)?max_len:line_len;
171         for(unsigned int i = offset; i < offset+len; i++) { 
172           if(i >= seq.size()) { 
173             cerr << "Error: offset of " << i << " exceeds read data: " << seq.size() << " in string of " << sequence.length() << endl; 
174             exit(1); 
175           }
176           char con = toupper(seq[i].consensus()); 
177           // weak consensus if lowercase.
178           bool weak_con = seq[i].consensus() != con;
179           if(con == ' ' || weak_con || toupper(con) == toupper(sequence[i])) { 
180             output.append(1,sequence[i]); 
181           } else { 
182             output.append(1,con); 
183             char buff[128];
184             sprintf(buff,"%d:%c>%c",i,sequence[i],con);
185             string var = buff;
186             variants.push_back(var);
187           }
188         }
189         //output += '\n';
190       }
191       o << ">" << this->chr << ":" << this->pos << "-" << this->pos+this->length << "|";
192       for(vector<string>::iterator i = variants.begin(); i != variants.end(); ++i) {
193         o << (*i);
194         if(i+1 != variants.end()) o << "|";
195       }
196       o << endl << output << endl;
197     }
198
199     void print_RE(ostream& o) {
200       for(unsigned int i = 0; i < sequence.length(); i++) {
201           char ref = toupper(sequence[i]);
202           char con = toupper(seq[i].consensus());
203           if(con != ' ' && con != ref) {
204             o << i << ":" << seq[i].consensus() << " (" << seq[i].RE() << ") -- [" << seq[i].A() << "," << seq[i].C() << "," << seq[i].G() << "," << seq[i].T() << "]" << endl;
205           }
206       }
207     }
208
209     void print_logo(ostream& o) {
210       unsigned int max = 0;
211       for(unsigned int i = 0; i < sequence.length(); i++) {
212         if(seq[i].N() > max) { max = seq[i].N(); }
213       }
214    
215       for(unsigned int i = 0; i < max; i++) {
216         for(unsigned int j = 0; j < sequence.length(); j++) {
217           o << seq[j].nth_nuc(i);
218         }
219         o << endl;
220       }
221     }
222 };
223
224 typedef vector<Window> Windows;
225
226 class SNP : public Loci {
227   public:
228
229     string name;
230     char reference_base;
231     char consensus[4]; // represent the consensus sequence in order. Most often, only the first 1 or 2 will matter.
232     unsigned int A;
233     unsigned int C;
234     unsigned int G;
235     unsigned int T;
236     unsigned int N;
237     unsigned int total;
238
239     SNP(string name, string chr, unsigned int pos, char reference_base) : Loci(chr,pos) { 
240       this->name = name; 
241       this->A = 0;
242       this->C = 0;
243       this->G = 0;
244       this->T = 0;
245       this->N = 0;
246
247       this->reference_base = reference_base; 
248     }
249
250     SNP(const SNP& h) : Loci(h) {
251       this->name = h.name; 
252       this->A = h.A; this->C = h.C; this->G = h.G; this->T = h.T; this->total = h.total;
253       this->reference_base = h.reference_base; 
254     }
255
256     SNP& operator=(const SNP& h) {
257       this->name = h.name; 
258       this->chr = h.chr; 
259       this->pos = h.pos; 
260       this->A = h.A; this->C = h.C; this->G = h.G; this->T = h.T; this->total = h.total;
261       this->reference_base = h.reference_base; 
262       return *this;
263     }
264
265     void eval_consensus() {
266       // if A is the max
267       if(A >= C & A >= G & A >= T) { consensus[0] = 'A'; 
268         if(C >= G & C >= T) { consensus[1] = 'C'; 
269           if(G >= T) { consensus[2] = 'G'; consensus[3] = 'T'; }
270           else       { consensus[2] = 'T'; consensus[3] = 'G'; }
271         } else if(G >= C & G >= T) { consensus[1] = 'G'; 
272           if(C >= T) { consensus[2] = 'C'; consensus[3] = 'T'; }
273           else       { consensus[2] = 'T'; consensus[3] = 'C'; }
274         } else { consensus[1] = 'T'; 
275           if(C >= G) { consensus[2] = 'C'; consensus[3] = 'G'; }
276           else       { consensus[2] = 'G'; consensus[3] = 'C'; }
277         }
278
279
280       // if C is the max
281       } else if(C >= A & C >= G & C >= T) { consensus[0] = 'C'; 
282         if(A >= G & A >= T) { consensus[1] = 'A'; 
283           if(G >= T) { consensus[2] = 'G'; consensus[3] = 'T'; }
284           else       { consensus[2] = 'T'; consensus[3] = 'G'; }
285         } else if(G >= A & G >= T) { consensus[1] = 'G'; 
286           if(A >= T) { consensus[2] = 'A'; consensus[3] = 'T'; }
287           else       { consensus[2] = 'T'; consensus[3] = 'A'; }
288         } else { consensus[1] = 'T'; 
289           if(A >= G) { consensus[2] = 'A'; consensus[3] = 'G'; }
290           else       { consensus[2] = 'G'; consensus[3] = 'A'; }
291         }
292       } else if(G >= A & G >= C & G >= T) { consensus[0] = 'G'; 
293         if(A >= C & A >= T) { consensus[1] = 'A'; 
294           if(C >= T) { consensus[2] = 'C'; consensus[3] = 'T'; }
295           else       { consensus[2] = 'T'; consensus[3] = 'C'; }
296         } else if(C >= A & C >= T) { consensus[1] = 'C'; 
297           if(A >= T) { consensus[2] = 'A'; consensus[3] = 'T'; }
298           else       { consensus[2] = 'T'; consensus[3] = 'A'; }
299         } else { consensus[1] = 'T'; 
300           if(A >= C) { consensus[2] = 'A'; consensus[3] = 'C'; }
301           else       { consensus[2] = 'C'; consensus[3] = 'A'; }
302         }
303       } else { consensus[0] = 'T'; 
304         if(A >= C & A >= G) { consensus[1] = 'A'; 
305           if(C >= G) { consensus[2] = 'C'; consensus[3] = 'G'; }
306           else       { consensus[2] = 'G'; consensus[3] = 'C'; }
307         } else if(C >= A & C >= G) { consensus[1] = 'C'; 
308           if(A >= G) { consensus[2] = 'A'; consensus[3] = 'G'; }
309           else       { consensus[2] = 'G'; consensus[3] = 'A'; }
310         } else { consensus[1] = 'G'; 
311           if(A >= C) { consensus[2] = 'A'; consensus[3] = 'C'; }
312           else       { consensus[2] = 'C'; consensus[3] = 'A'; }
313         }
314       }
315     }
316
317     void add_read(char nuc) {
318       switch(nuc) {
319         case 'a':
320         case 'A':
321           A++; break;
322         case 'c':
323         case 'C':
324           C++; break;
325         case 'g':
326         case 'G':
327           G++; break;
328         case 't':
329         case 'T':
330           T++; break;
331         default:
332           N++; break;
333       }
334       total++;
335     }
336
337   void clean(unsigned int threshold) {
338     if(A <= threshold) { A = 0; }
339     if(C <= threshold) { C = 0; }
340     if(G <= threshold) { G = 0; }
341     if(T <= threshold) { T = 0; }
342     total = A + C + G + T;
343     eval_consensus();
344   }
345
346   double RE(unsigned int th = 2) { 
347     if(total == 0) { return 0.0; }
348
349     double pA = (double)( ((A<th)?A:0)+1e-10)/(double)total;
350     double pC = (double)( ((C<th)?C:0)+1e-10)/(double)total;
351     double pG = (double)( ((G<th)?G:0)+1e-10)/(double)total;
352     double pT = (double)( ((T<th)?T:0)+1e-10)/(double)total;
353
354     //assume equal distribution of A,C,G,T
355     double l2 = log(2);
356     return pA*log(pA/0.25)/l2 + pC*log(pC/0.25)/l2 + pG*log(pG/0.25)/l2 + pT*log(pT/0.25)/l2;
357   }
358 };
359
360 typedef vector<SNP> SNPs;
361
362 //Class to calulate mixture model. Very not general right now, but should be easy enough to make more general 
363 //if the need arises
364 class GaussianMixture {
365
366 public:
367   double p;
368   double u1;
369   double s1;
370   double u2;
371   double s2;
372   double Q;
373
374   unsigned int N;
375
376   double delta;
377
378   GaussianMixture(SNPs& snps, double delta = 1e-10) {
379     //initialize model
380     this->p = 0.5;
381     //model 1: heterozygous
382     this->u1 = 1.0;
383     this->s1 = 0.5;
384
385     //model 2: homozygous
386     this->u2 = 2.0;
387     this->s2 = 0.5;
388
389     this->delta = delta;
390   }
391
392   bool classify(double x) {
393     return(norm_prob(x,u1,s1) >= norm_prob(x,u2,s2)) ;
394   }
395
396   // Use EM to fit gaussian mixture model to discern heterozygous from homozygous snps
397   void fit(SNPs& snps, unsigned int count_th) {
398     //initialize relative entropy and probabilities
399     vector<double> RE; 
400     vector<double> pr;
401     for(unsigned int i = 0; i < snps.size(); ++i) {
402       if(snps[i].total >= 8) {
403         RE.push_back(snps[i].RE(count_th));
404         pr.push_back(0.5);
405       }
406     }
407
408     this->N = RE.size();
409
410     cerr << this->N << " snps checked\n";
411
412     //calculate initial expectation
413     this->Q = 0.0;
414     for(unsigned int i = 0; i < N; ++i) {
415       Q +=    pr[i]    * (log( this->p ) - log(sqrt(2.0*PI)) - log(this->s1) - (RE[i] - this->u1)*(RE[i] - this->u1)/(2.0*this->s1*this->s1));
416       Q += (1.0-pr[i]) * (log(1-this->p) - log(sqrt(2.0*PI)) - log(this->s2) - (RE[i] - this->u2)*(RE[i] - this->u2)/(2.0*this->s2*this->s2));
417     }
418
419     cerr << "Q: " << this->Q << endl;
420   
421     double Q_new = 0;
422     //expectation maximization to iteratively update pi's and parameters until Q settles down.
423     while(1) {
424       cerr << "loop Q: " << Q << endl;
425       Q_new = 0.0;
426   
427       double p_sum = 0.0, q_sum = 0.0, u1_sum = 0.0, u2_sum = 0.0;
428       for(unsigned int i = 0; i < N; ++i) {
429         pr[i] = pr[i]*norm_prob(RE[i],this->u1,this->s1) / 
430                 (pr[i]*norm_prob(RE[i],this->u1,this->s1) + (1.0 - pr[i])*(norm_prob(RE[i],this->u2,this->s2)));
431   
432         p_sum += pr[i];
433         q_sum += (1.0 - pr[i]);
434   
435         u1_sum += pr[i]*RE[i];
436         u2_sum += (1.0 - pr[i])*RE[i];
437   
438         Q_new += pr[i]      * (log( this->p ) - log(sqrt(2*PI)) - log(this->s1) - (RE[i] - this->u1)*(RE[i] - this->u1)/(2.0*this->s1*this->s1));
439         Q_new += (1.0-pr[i])* (log(1-this->p) - log(sqrt(2*PI)) - log(this->s2) - (RE[i] - this->u2)*(RE[i] - this->u2)/(2.0*this->s2*this->s2));
440       }
441   
442       //update variables of the distributions (interwoven with pi loop to save cpu)
443       this->p  = p_sum / this->N;
444       this->u1 = u1_sum / p_sum;
445       this->u2 = u2_sum / q_sum;
446        
447       double s1_sum = 0.0, s2_sum = 0.0;
448       for(unsigned int i = 0; i < N; ++i) {
449         s1_sum +=    pr[i]    * (RE[i] - this->u1)*(RE[i] - this->u1);
450         s2_sum += (1.0-pr[i]) * (RE[i] - this->u2)*(RE[i] - this->u2);
451       }
452       
453       this->s1 = sqrt(s1_sum/p_sum);
454       this->s2 = sqrt(s2_sum/q_sum); 
455  
456       if(fabs(this->Q - Q_new) < 1e-5) { break; }
457       this->Q = Q_new;
458     }
459     cerr << "Q: " << Q << endl;
460   }
461
462   void print_model() {
463     cout << "Q: " << Q << " p: " << p << " norm(" << u1 << "," << s1 << ");norm(" << u2 << "," << s2 << ")" << endl;
464   }
465 };
466
467
468 ostream &operator<<( ostream &out, const SNP &h ) {
469   out << h.name.c_str() << "\t" << h.chr.c_str() << "\t" << h.pos << "\t" << h.reference_base << "\t" << h.A << "\t" << h.C << "\t" << h.G << "\t" << h.T << "\t" << h.total;
470
471   return out;
472 }
473
474
475 void read_snps(const char* filename, SNPs& snps) {
476   string delim("\t");
477
478   ifstream feat(filename);
479   size_t N = 0;
480   while(feat.peek() != EOF) {
481     char line[1024];
482     feat.getline(line,1024,'\n');
483     N++;
484     string line_str(line);
485     vector<string> fields;
486     split(line_str, delim, fields);
487     if(fields.size() != 4) { cerr << "Error (" << filename << "): wrong number of fields in feature list (line " << N << " has " << fields.size() << " fields)\n"; }
488
489     string name = fields[0];
490     string chr = fields[1];
491     unsigned int pos = atoi(fields[2].c_str());
492     char base = (fields[3])[0];
493
494     SNP snp(name,chr,pos,base);
495     snps.push_back(snp);
496   } 
497
498   //sort the features so we can run through it once
499   std::stable_sort(snps.begin(),snps.end());
500   feat.close();
501
502   cerr << "Found AND sorted " << snps.size() << " snps." << endl;
503 }
504
505 void read_align_file(char* filename, Reads& features) {
506   string delim(" \n");
507   string location_delim(":");
508   char strand_str[2]; strand_str[1] = '\0';
509   ifstream seqs(filename);
510   string name("");
511   while(seqs.peek() != EOF) {
512     char line[2048];
513     seqs.getline(line,2048,'\n');
514
515     string line_str(line);
516     vector<string> fields;
517     split(line_str, delim, fields);
518     if(fields.size() != 7) { continue; }
519  
520     vector<string> location; split(fields[3], location_delim, location);
521     string chr = location[0];
522     if(chr == "newcontam") { continue; }
523     if(chr == "NA") { continue; }
524
525     int pos = atoi(location[1].c_str());
526     bool strand = ((fields[4].c_str())[0] == 'F')?0:1;
527
528     string seq;
529     if(strand == 0) { seq = fields[0]; } else { revcomp(seq,fields[0]); }
530     Read read(chr,pos,0,seq); 
531     features.push_back(read);
532   }
533   seqs.close(); 
534
535   //sort the data so we can run through it once
536   std::sort(features.begin(),features.end());
537   cerr << "Found and sorted " << features.size() << " reads." << endl;
538 }
539
540 void read_window_file(const char* filename, Windows& ws) {
541   string delim("\t");
542
543   ifstream win_file(filename);
544
545   unsigned int N = 0;
546   while(win_file.peek() != EOF) {
547     char line[1024];
548     win_file.getline(line,1024,'\n');
549     N++;
550     string line_str(line);
551     vector<string> fields;
552     split(line_str, delim, fields);
553     if(fields.size() < 5) { cerr << "Error (" << filename << "): wrong number of fields in feature list (line " << N << " has " << fields.size() << " fields)\n"; }
554
555     string name = fields[0];
556     string chr = fields[1];
557     if(chr == "NA") { continue; }
558     if(chr == "contam") { continue; }
559     int start = atoi(fields[2].c_str());
560     int stop = atoi(fields[3].c_str());
561
562     Window w(name,chr,start,stop-start+1);
563     ws.push_back(w);
564   } 
565
566   //sort the features so we can run through it once
567   std::stable_sort(ws.begin(),ws.end());
568   win_file.close();
569
570   cerr << "Found and sorted " << ws.size() << " windows." << endl;
571 }
572
573 void count_read_in_features(Windows& windows, Reads& data) {
574   Windows::iterator wind_it = windows.begin();
575   
576   for(Reads::iterator i = data.begin(); i != data.end(); ++i) {
577     //skip to first feature after read
578     string start_chr = wind_it->chr;
579     while(wind_it != windows.end() && (wind_it->chr < i->chr || (wind_it->chr == i->chr && wind_it->pos + wind_it->length < i->pos) )) {
580       wind_it++;
581     }
582     
583     //stop if we have run out of features.
584     if(wind_it == windows.end()) { break; }
585
586     if(i->pos + i->length() > wind_it->pos && i->pos < (wind_it->pos + wind_it->length)) {
587       wind_it->add_read(*i);
588     }
589   }
590 }
591
592 void retrieveSequenceData(ChromList chrom_filenames, Windows& peaks) {
593         char temp[1024];
594
595         string chrom = peaks[0].chr;
596         string chrom_filename = chrom_filenames[chrom];
597         ifstream chrom_file(chrom_filename.c_str());
598         chrom_file.getline(temp, 1024);
599         size_t offset = chrom_file.gcount();
600         for(Windows::iterator i = peaks.begin(); i != peaks.end(); ++i) {
601           if(i->chr != chrom) { 
602             chrom = i->chr; 
603             cout << "XXX: " << (*(chrom_filenames.find(chrom))).first << endl;
604             if(chrom_filenames.find(chrom) == chrom_filenames.end()) {
605               cerr << "Chrom: " << chrom << " not found\n";
606             }
607             chrom_filename = chrom_filenames[chrom];
608             chrom_file.close(); chrom_file.open(chrom_filename.c_str());
609             chrom_file.getline(temp, 1024);
610             offset = chrom_file.gcount();
611           }
612           unsigned int begin = i->pos - 1;
613           unsigned int end   = i->pos+i->length;
614  
615           unsigned int begin_pos = offset + (int)begin/50 + begin;      
616           unsigned int end_pos = offset + (int)end/50 + end;      
617
618           unsigned int read_len = end_pos - begin_pos;
619           char buffer[read_len+1];
620           chrom_file.seekg(begin_pos, ios_base::beg);
621           chrom_file.read(buffer, read_len);
622           buffer[read_len] = '\0';
623           i->set_sequence(buffer);
624         } 
625         chrom_file.close(); 
626 }
627
628
629 int main(int argc, char** argv) {
630   if(argc != 4) { cerr << "Usage: " << argv[0] << " read_file window_file chromosome_file\n"; exit(1); }
631
632   char read_filename[1024]; strcpy(read_filename,argv[1]);
633   char window_filename[1024]; strcpy(window_filename,argv[2]);
634   char chromosome_filename[1024]; strcpy(chromosome_filename,argv[3]);
635
636   Windows windows; read_window_file(window_filename, windows);
637
638   if(windows.size() == 0) { cout << "No windows loaded." << endl; exit(0); }
639   ChromList reference_seq(chromosome_filename);
640
641   retrieveSequenceData(reference_seq, windows);
642
643   cerr << "Established reference sequences\n";
644
645   Reads reads; read_align_file(read_filename, reads);
646
647   if(reads.size() == 0) { cout << "No reads loaded." << endl; exit(0); }
648
649   count_read_in_features(windows, reads);
650
651   for(Windows::iterator w = windows.begin(); w != windows.end(); ++w) {
652     //w->print_consensus(cout);
653     //w->print_logo(cout);
654     w->print_RE(cerr);
655     w->print_fasta(cout);
656   }
657 }