Added code to identify snps and update the gneome accordingly. This code has an insid...
[htsworkflow.git] / htswanalysis / src / GetReadsInSnps / getReadsInSnps.cpp
1 /*
2  * getReadsInSnps:
3  * This program takes a set of snps in a custom tab format, and a set of short mapped reads, and evaluates
4  * the sequencing overlap over those snps. Additionally, a miaxture model is fit and used to classify the 
5  * snps as homozygous or heterozygous.
6  * 
7  * In the final report, the output is:
8  * <snp id> <chromosome> <position> <reference base> <a count> <c count> <g count> <t count> <total count> <snp call>
9  * where snp call is one of:
10  * -1: no call was made (not enough examples to make a call)
11  * 0: the snp is homozygous
12  * 1: the snp is heterozygous
13  *
14  */
15 #include <sys/types.h>
16 #include <iostream>
17 #include <fstream>
18 #include <vector>
19 #include <map>
20 #include <queue>
21 #include <math.h>
22 #include <string>
23
24 #include <gsl/gsl_statistics.h>
25
26 #define WINDOW 25
27 #define PI 3.14159265358979323846
28
29 //#define DEBUG
30
31 using namespace std;
32
33 void split (const string& text, const string& separators, vector<string>& words);
34 char *strrevcomp(const string& input);
35
36 double norm_prob(double x, double mu, double s) { return (1.0)/(s*sqrt(2*PI)) * exp(-0.5*(x-mu)*(x-mu)/(s*s)); }
37
38 class Loci {
39   public:
40     string chr;
41     unsigned int pos;
42
43     Loci(string chr, unsigned int pos) { this->chr = chr; this->pos = pos; }
44     Loci(const Loci& l) { this->chr = l.chr; this->pos = l.pos; }
45     Loci& operator=(const Loci& l) { this->chr = l.chr; this->pos = l.pos; return *this; }
46
47     bool operator<(const Loci& a) const { if(this->chr == a.chr) { return this->pos < a.pos; } else { return this->chr < a.chr; } }
48
49 };
50
51
52 class Read : public Loci {
53   public:
54     string seq;
55
56     Read(string chr, unsigned int pos, string seq) : Loci(chr,pos) { this->seq = seq; }
57     Read(const Read& r) : Loci(r) { this->seq = r.seq; }
58     Read& operator=(const Read& r) { this->chr = r.chr; this->pos = r.pos; this->seq = r.seq; return *this;}
59 };
60
61 typedef vector<Read> Reads;
62
63 class SNP : public Loci {
64   public:
65
66     string name;
67     char reference_base;
68     char consensus[4]; // represent the consensus sequence in order. Most often, only the first 1 or 2 will matter.
69     unsigned int A;
70     unsigned int C;
71     unsigned int G;
72     unsigned int T;
73     unsigned int N;
74     unsigned int total;
75
76     SNP(string name, string chr, unsigned int pos, char reference_base) : Loci(chr,pos) { 
77       this->name = name; 
78       this->A = 0;
79       this->C = 0;
80       this->G = 0;
81       this->T = 0;
82       this->N = 0;
83
84       this->reference_base = reference_base; 
85     }
86
87     SNP(const SNP& h) : Loci(h) {
88       this->name = h.name; 
89       this->A = h.A; this->C = h.C; this->G = h.G; this->T = h.T; this->total = h.total;
90       this->reference_base = h.reference_base; 
91     }
92
93     SNP& operator=(const SNP& h) {
94       this->name = h.name; 
95       this->chr = h.chr; 
96       this->pos = h.pos; 
97       this->A = h.A; this->C = h.C; this->G = h.G; this->T = h.T; this->total = h.total;
98       this->reference_base = h.reference_base; 
99       return *this;
100     }
101
102     void eval_consensus() {
103       // if A is the max
104       if(A >= C & A >= G & A >= T) { consensus[0] = 'A'; 
105         if(C >= G & C >= T) { consensus[1] = 'C'; 
106           if(G >= T) { consensus[2] = 'G'; consensus[3] = 'T'; }
107           else       { consensus[2] = 'T'; consensus[3] = 'G'; }
108         } else if(G >= C & G >= T) { consensus[1] = 'G'; 
109           if(C >= T) { consensus[2] = 'C'; consensus[3] = 'T'; }
110           else       { consensus[2] = 'T'; consensus[3] = 'C'; }
111         } else { consensus[1] = 'T'; 
112           if(C >= G) { consensus[2] = 'C'; consensus[3] = 'G'; }
113           else       { consensus[2] = 'G'; consensus[3] = 'C'; }
114         }
115
116
117       // if C is the max
118       } else if(C >= A & C >= G & C >= T) { consensus[0] = 'C'; 
119         if(A >= G & A >= T) { consensus[1] = 'A'; 
120           if(G >= T) { consensus[2] = 'G'; consensus[3] = 'T'; }
121           else       { consensus[2] = 'T'; consensus[3] = 'G'; }
122         } else if(G >= A & G >= T) { consensus[1] = 'G'; 
123           if(A >= T) { consensus[2] = 'A'; consensus[3] = 'T'; }
124           else       { consensus[2] = 'T'; consensus[3] = 'A'; }
125         } else { consensus[1] = 'T'; 
126           if(A >= G) { consensus[2] = 'A'; consensus[3] = 'G'; }
127           else       { consensus[2] = 'G'; consensus[3] = 'A'; }
128         }
129       } else if(G >= A & G >= C & G >= T) { consensus[0] = 'G'; 
130         if(A >= C & A >= T) { consensus[1] = 'A'; 
131           if(C >= T) { consensus[2] = 'C'; consensus[3] = 'T'; }
132           else       { consensus[2] = 'T'; consensus[3] = 'C'; }
133         } else if(C >= A & C >= T) { consensus[1] = 'C'; 
134           if(A >= T) { consensus[2] = 'A'; consensus[3] = 'T'; }
135           else       { consensus[2] = 'T'; consensus[3] = 'A'; }
136         } else { consensus[1] = 'T'; 
137           if(A >= C) { consensus[2] = 'A'; consensus[3] = 'C'; }
138           else       { consensus[2] = 'C'; consensus[3] = 'A'; }
139         }
140       } else { consensus[0] = 'T'; 
141         if(A >= C & A >= G) { consensus[1] = 'A'; 
142           if(C >= G) { consensus[2] = 'C'; consensus[3] = 'G'; }
143           else       { consensus[2] = 'G'; consensus[3] = 'C'; }
144         } else if(C >= A & C >= G) { consensus[1] = 'C'; 
145           if(A >= G) { consensus[2] = 'A'; consensus[3] = 'G'; }
146           else       { consensus[2] = 'G'; consensus[3] = 'A'; }
147         } else { consensus[1] = 'G'; 
148           if(A >= C) { consensus[2] = 'A'; consensus[3] = 'C'; }
149           else       { consensus[2] = 'C'; consensus[3] = 'A'; }
150         }
151       }
152     }
153
154     void add_read(char nuc) {
155       switch(nuc) {
156         case 'a':
157         case 'A':
158           A++; break;
159         case 'c':
160         case 'C':
161           C++; break;
162         case 'g':
163         case 'G':
164           G++; break;
165         case 't':
166         case 'T':
167           T++; break;
168         default:
169           N++; break;
170       }
171       total++;
172     }
173
174   void clean(unsigned int threshold) {
175     if(A <= threshold) { A = 0; }
176     if(C <= threshold) { C = 0; }
177     if(G <= threshold) { G = 0; }
178     if(T <= threshold) { T = 0; }
179     total = A + C + G + T;
180     eval_consensus();
181   }
182
183   double RE(unsigned int th = 2) { 
184     if(total == 0) { return 0.0; }
185
186     double pA = (double)( ((A<th)?A:0)+1e-10)/(double)total;
187     double pC = (double)( ((C<th)?C:0)+1e-10)/(double)total;
188     double pG = (double)( ((G<th)?G:0)+1e-10)/(double)total;
189     double pT = (double)( ((T<th)?T:0)+1e-10)/(double)total;
190
191     //assume equal distribution of A,C,G,T
192     double l2 = log(2);
193     return pA*log(pA/0.25)/l2 + pC*log(pC/0.25)/l2 + pG*log(pG/0.25)/l2 + pT*log(pT/0.25)/l2;
194   }
195 };
196
197 typedef vector<SNP> SNPs;
198
199 //Class to calulate mixture model. Very not general right now, but should be easy enough to make more general 
200 //if the need arises
201 class GaussianMixture {
202
203 public:
204   double p;
205   double u1;
206   double s1;
207   double u2;
208   double s2;
209   double Q;
210
211   unsigned int N;
212
213   double delta;
214
215   GaussianMixture(SNPs& snps, double delta = 1e-10) {
216     //initialize model
217     this->p = 0.5;
218     //model 1: heterozygous
219     this->u1 = 1.0;
220     this->s1 = 0.5;
221
222     //model 2: homozygous
223     this->u2 = 2.0;
224     this->s2 = 0.5;
225
226     this->delta = delta;
227   }
228
229   bool classify(double x) {
230     return(norm_prob(x,u1,s1) >= norm_prob(x,u2,s2)) ;
231   }
232
233   // Use EM to fit gaussian mixture model to discern heterozygous from homozygous snps
234   void fit(SNPs& snps, unsigned int count_th) {
235     //initialize relative entropy and probabilities
236     vector<double> RE; 
237     vector<double> pr;
238     for(unsigned int i = 0; i < snps.size(); ++i) {
239       if(snps[i].total >= 8) {
240         RE.push_back(snps[i].RE(count_th));
241         pr.push_back(0.5);
242       }
243     }
244
245     this->N = RE.size();
246
247     cerr << this->N << " snps checked\n";
248
249     //calculate initial expectation
250     this->Q = 0.0;
251     for(unsigned int i = 0; i < N; ++i) {
252       Q +=    pr[i]    * (log( this->p ) - log(sqrt(2.0*PI)) - log(this->s1) - (RE[i] - this->u1)*(RE[i] - this->u1)/(2.0*this->s1*this->s1));
253       Q += (1.0-pr[i]) * (log(1-this->p) - log(sqrt(2.0*PI)) - log(this->s2) - (RE[i] - this->u2)*(RE[i] - this->u2)/(2.0*this->s2*this->s2));
254     }
255
256     cerr << "Q: " << this->Q << endl;
257   
258     double Q_new = 0;
259     //expectation maximization to iteratively update pi's and parameters until Q settles down.
260     while(1) {
261       cerr << "loop Q: " << Q << endl;
262       Q_new = 0.0;
263   
264       double p_sum = 0.0, q_sum = 0.0, u1_sum = 0.0, u2_sum = 0.0;
265       for(unsigned int i = 0; i < N; ++i) {
266         pr[i] = pr[i]*norm_prob(RE[i],this->u1,this->s1) / 
267                 (pr[i]*norm_prob(RE[i],this->u1,this->s1) + (1.0 - pr[i])*(norm_prob(RE[i],this->u2,this->s2)));
268   
269         p_sum += pr[i];
270         q_sum += (1.0 - pr[i]);
271   
272         u1_sum += pr[i]*RE[i];
273         u2_sum += (1.0 - pr[i])*RE[i];
274   
275         Q_new += pr[i]      * (log( this->p ) - log(sqrt(2*PI)) - log(this->s1) - (RE[i] - this->u1)*(RE[i] - this->u1)/(2.0*this->s1*this->s1));
276         Q_new += (1.0-pr[i])* (log(1-this->p) - log(sqrt(2*PI)) - log(this->s2) - (RE[i] - this->u2)*(RE[i] - this->u2)/(2.0*this->s2*this->s2));
277       }
278   
279       //update variables of the distributions (interwoven with pi loop to save cpu)
280       this->p  = p_sum / this->N;
281       this->u1 = u1_sum / p_sum;
282       this->u2 = u2_sum / q_sum;
283        
284       double s1_sum = 0.0, s2_sum = 0.0;
285       for(unsigned int i = 0; i < N; ++i) {
286         s1_sum +=    pr[i]    * (RE[i] - this->u1)*(RE[i] - this->u1);
287         s2_sum += (1.0-pr[i]) * (RE[i] - this->u2)*(RE[i] - this->u2);
288       }
289       
290       this->s1 = sqrt(s1_sum/p_sum);
291       this->s2 = sqrt(s2_sum/q_sum); 
292  
293       if(fabs(this->Q - Q_new) < 1e-5) { break; }
294       this->Q = Q_new;
295     }
296     cerr << "Q: " << Q << endl;
297   }
298
299   void print_model() {
300     cout << "Q: " << Q << " p: " << p << " norm(" << u1 << "," << s1 << ");norm(" << u2 << "," << s2 << ")" << endl;
301   }
302 };
303
304
305 ostream &operator<<( ostream &out, const SNP &h ) {
306   out << h.name.c_str() << "\t" << h.chr.c_str() << "\t" << h.pos << "\t" << h.reference_base << "\t" << h.A << "\t" << h.C << "\t" << h.G << "\t" << h.T << "\t" << h.total;
307
308   return out;
309 }
310
311
312 void read_snps(const char* filename, SNPs& snps) {
313   string delim("\t");
314
315   ifstream feat(filename);
316   size_t N = 0;
317   while(feat.peek() != EOF) {
318     char line[1024];
319     feat.getline(line,1024,'\n');
320     N++;
321     string line_str(line);
322     vector<string> fields;
323     split(line_str, delim, fields);
324     if(fields.size() != 4) { cerr << "Error (" << filename << "): wrong number of fields in feature list (line " << N << " has " << fields.size() << " fields)\n"; }
325
326     string name = fields[0];
327     string chr = fields[1];
328     unsigned int pos = atoi(fields[2].c_str());
329     char base = (fields[3])[0];
330
331     SNP snp(name,chr,pos,base);
332     snps.push_back(snp);
333   } 
334
335   //sort the features so we can run through it once
336   std::stable_sort(snps.begin(),snps.end());
337   feat.close();
338
339   cerr << "Found and sorted " << snps.size() << " snps." << endl;
340 }
341
342
343
344 void read_align_file(char* filename, Reads& features) {
345   string delim(" \n");
346   string location_delim(":");
347   char strand_str[2]; strand_str[1] = '\0';
348   ifstream seqs(filename);
349   string name("");
350   while(seqs.peek() != EOF) {
351     char line[2048];
352     seqs.getline(line,2048,'\n');
353
354     string line_str(line);
355     vector<string> fields;
356     split(line_str, delim, fields);
357     if(fields.size() != 7) { continue; }
358
359  
360     vector<string> location; split(fields[3], location_delim, location);
361     string chr = location[0];
362     if(chr == "NA") { continue; }
363     int pos = atoi(location[1].c_str());
364     bool strand = ((fields[4].c_str())[0] == 'F')?0:1;
365
366     string seq;
367     if(strand == 0) { seq = fields[0]; } else { seq = strrevcomp(fields[0]); }
368     Read read(chr,pos,seq); 
369     features.push_back(read);
370   }
371   seqs.close(); 
372
373   //sort the data so we can run through it once
374   std::sort(features.begin(),features.end());
375   cerr << "Found and sorted " << features.size() << " reads." << endl;
376 }
377
378 void count_read_at_snps(SNPs& snps, Reads& data) {
379   SNPs::iterator snp_it = snps.begin();
380
381   //assume, for now, that all reads have the same length
382   unsigned int read_len = data[0].seq.length();
383   
384   for(Reads::iterator i = data.begin(); i != data.end(); ++i) {
385     //skip to first feature after read
386     string start_chr = snp_it->chr;
387     while(snp_it != snps.end() && *snp_it < *i) {
388       snp_it++;
389     }
390     
391     //stop if we have run out of features.
392     if(snp_it == snps.end()) { break; }
393
394     if(i->pos + read_len > snp_it->pos && i->pos <= snp_it->pos) {
395       snp_it->add_read(i->seq[snp_it->pos - i->pos]);
396     }
397   }
398 }
399
400 int main(int argc, char** argv) {
401   if(argc != 4) { cerr << "Usage: " << argv[0] << " snp_file read_file non_reference_output_file\n"; exit(1); }
402
403   char snp_filename[1024]; strcpy(snp_filename,argv[1]);
404   char read_filename[1024]; strcpy(read_filename,argv[2]);
405   char nonref_filename[1024]; strcpy(nonref_filename,argv[3]);
406
407   SNPs snps; read_snps(snp_filename, snps);
408   Reads reads; read_align_file(read_filename, reads);
409
410   count_read_at_snps(snps, reads);
411
412   //fix a guassian mixture model to the snps to classify
413   GaussianMixture g(snps);
414   g.fit(snps, 2);
415
416 #ifdef DEBUG
417   g.print_model();
418 #endif
419
420   ofstream nonref(nonref_filename);
421   int group;
422   for(SNPs::iterator i = snps.begin(); i != snps.end(); ++i) {
423     group = -1; 
424     if(i->total >= 10) { i->eval_consensus(); group = g.classify(i->RE()); }
425       cout << (*i) << "\t" << group << "\t";
426       if(group == 0)      cout << i->consensus[0];
427       else if(group == 1) cout << i->consensus[0] << "," << i->consensus[1];
428
429       if( ( group == 0 && i->consensus[0] != toupper(i->reference_base) ) || group == 1) {
430         //detected difference from consensus sequence
431         nonref <<i->chr << "\t" << i->pos << "\t";
432         if(group == 0) { nonref << i->consensus[0] << endl; }
433         if(group == 1) { 
434           if(i->consensus[0] != toupper(i->reference_base)) {
435             nonref << i->consensus[0] << endl; 
436           } else {
437             nonref << i->consensus[1] << endl; 
438           }
439         }
440       } 
441       cout << endl;
442   }
443   nonref.close();
444 }
445
446 void split (const string& text, const string& separators, vector<string>& words) {
447
448     size_t n     = text.length ();
449     size_t start = text.find_first_not_of (separators);
450
451     while (start < n) {
452         size_t stop = text.find_first_of (separators, start);
453         if (stop > n) stop = n;
454         words.push_back (text.substr (start, stop-start));
455         start = text.find_first_not_of (separators, stop+1);
456     }
457 }
458
459 char *strrevcomp(const string& input)
460 {
461   char* str = new char[input.length()];
462   strcpy(str,input.c_str());
463
464   char *p1, *p2;
465
466   if (! str || ! *str)
467     return str;
468
469   for (p1 = str, p2 = str + strlen(str) - 1; p2 > p1; ++p1, --p2) {
470     *p1 ^= *p2;
471     *p2 ^= *p1;
472     *p1 ^= *p2;
473   }
474
475   for (p1 = str; p1 < str + strlen(str); ++p1) {
476     if(*p1 == 'a' || *p1 == 'A') { *p1 = 'T'; } 
477     else if(*p1 == 'c' || *p1 == 'C') { *p1 = 'G'; } 
478     else if(*p1 == 'g' || *p1 == 'G') { *p1 = 'C'; } 
479     else if(*p1 == 't' || *p1 == 'T') { *p1 = 'A'; } 
480   }
481
482   return str;
483 }
484