# New usptream release; no change for licenses or copyrights.
[samtools.git] / kprobaln.c
1 /* The MIT License
2
3    Copyright (c) 2003-2006, 2008-2010, by Heng Li <lh3lh3@live.co.uk>
4
5    Permission is hereby granted, free of charge, to any person obtaining
6    a copy of this software and associated documentation files (the
7    "Software"), to deal in the Software without restriction, including
8    without limitation the rights to use, copy, modify, merge, publish,
9    distribute, sublicense, and/or sell copies of the Software, and to
10    permit persons to whom the Software is furnished to do so, subject to
11    the following conditions:
12
13    The above copyright notice and this permission notice shall be
14    included in all copies or substantial portions of the Software.
15
16    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
17    EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
18    MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
19    NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
20    BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
21    ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
22    CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23    SOFTWARE.
24 */
25
26 #include <stdlib.h>
27 #include <stdio.h>
28 #include <string.h>
29 #include <stdint.h>
30 #include <math.h>
31 #include "kprobaln.h"
32
33 /*****************************************
34  * Probabilistic banded glocal alignment *
35  *****************************************/
36
37 #define EI .25
38 #define EM .33333333333
39
40 static float g_qual2prob[256];
41
42 #define set_u(u, b, i, k) { int x=(i)-(b); x=x>0?x:0; (u)=((k)-x+1)*3; }
43
44 kpa_par_t kpa_par_def = { 0.001, 0.1, 10 };
45 kpa_par_t kpa_par_alt = { 0.0001, 0.01, 10 };
46
47 /*
48   The topology of the profile HMM:
49
50            /\             /\        /\             /\
51            I[1]           I[k-1]    I[k]           I[L]
52             ^   \      \    ^    \   ^   \      \   ^
53             |    \      \   |     \  |    \      \  |
54     M[0]   M[1] -> ... -> M[k-1] -> M[k] -> ... -> M[L]   M[L+1]
55                 \      \/        \/      \/      /
56                  \     /\        /\      /\     /
57                        -> D[k-1] -> D[k] ->
58
59    M[0] points to every {M,I}[k] and every {M,I}[k] points M[L+1].
60
61    On input, _ref is the reference sequence and _query is the query
62    sequence. Both are sequences of 0/1/2/3/4 where 4 stands for an
63    ambiguous residue. iqual is the base quality. c sets the gap open
64    probability, gap extension probability and band width.
65
66    On output, state and q are arrays of length l_query. The higher 30
67    bits give the reference position the query base is matched to and the
68    lower two bits can be 0 (an alignment match) or 1 (an
69    insertion). q[i] gives the phred scaled posterior probability of
70    state[i] being wrong.
71  */
72 int kpa_glocal(const uint8_t *_ref, int l_ref, const uint8_t *_query, int l_query, const uint8_t *iqual,
73                            const kpa_par_t *c, int *state, uint8_t *q)
74 {
75         double **f, **b = 0, *s, m[9], sI, sM, bI, bM, pb;
76         float *qual, *_qual;
77         const uint8_t *ref, *query;
78         int bw, bw2, i, k, is_diff = 0, is_backward = 1, Pr;
79
80         /*** initialization ***/
81         is_backward = state && q? 1 : 0;
82         ref = _ref - 1; query = _query - 1; // change to 1-based coordinate
83         bw = l_ref > l_query? l_ref : l_query;
84         if (bw > c->bw) bw = c->bw;
85         if (bw < abs(l_ref - l_query)) bw = abs(l_ref - l_query);
86         bw2 = bw * 2 + 1;
87         // allocate the forward and backward matrices f[][] and b[][] and the scaling array s[]
88         f = calloc(l_query+1, sizeof(void*));
89         if (is_backward) b = calloc(l_query+1, sizeof(void*));
90         for (i = 0; i <= l_query; ++i) {
91                 f[i] = calloc(bw2 * 3 + 6, sizeof(double)); // FIXME: this is over-allocated for very short seqs
92                 if (is_backward) b[i] = calloc(bw2 * 3 + 6, sizeof(double));
93         }
94         s = calloc(l_query+2, sizeof(double)); // s[] is the scaling factor to avoid underflow
95         // initialize qual
96         _qual = calloc(l_query, sizeof(float));
97         if (g_qual2prob[0] == 0)
98                 for (i = 0; i < 256; ++i)
99                         g_qual2prob[i] = pow(10, -i/10.);
100         for (i = 0; i < l_query; ++i) _qual[i] = g_qual2prob[iqual? iqual[i] : 30];
101         qual = _qual - 1;
102         // initialize transition probability
103         sM = sI = 1. / (2 * l_query + 2); // the value here seems not to affect results; FIXME: need proof
104         m[0*3+0] = (1 - c->d - c->d) * (1 - sM); m[0*3+1] = m[0*3+2] = c->d * (1 - sM);
105         m[1*3+0] = (1 - c->e) * (1 - sI); m[1*3+1] = c->e * (1 - sI); m[1*3+2] = 0.;
106         m[2*3+0] = 1 - c->e; m[2*3+1] = 0.; m[2*3+2] = c->e;
107         bM = (1 - c->d) / l_ref; bI = c->d / l_ref; // (bM+bI)*l_ref==1
108         /*** forward ***/
109         // f[0]
110         set_u(k, bw, 0, 0);
111         f[0][k] = s[0] = 1.;
112         { // f[1]
113                 double *fi = f[1], sum;
114                 int beg = 1, end = l_ref < bw + 1? l_ref : bw + 1, _beg, _end;
115                 for (k = beg, sum = 0.; k <= end; ++k) {
116                         int u;
117                         double e = (ref[k] > 3 || query[1] > 3)? 1. : ref[k] == query[1]? 1. - qual[1] : qual[1] * EM;
118                         set_u(u, bw, 1, k);
119                         fi[u+0] = e * bM; fi[u+1] = EI * bI;
120                         sum += fi[u] + fi[u+1];
121                 }
122                 // rescale
123                 s[1] = sum;
124                 set_u(_beg, bw, 1, beg); set_u(_end, bw, 1, end); _end += 2;
125                 for (k = _beg; k <= _end; ++k) fi[k] /= sum;
126         }
127         // f[2..l_query]
128         for (i = 2; i <= l_query; ++i) {
129                 double *fi = f[i], *fi1 = f[i-1], sum, qli = qual[i];
130                 int beg = 1, end = l_ref, x, _beg, _end;
131                 uint8_t qyi = query[i];
132                 x = i - bw; beg = beg > x? beg : x; // band start
133                 x = i + bw; end = end < x? end : x; // band end
134                 for (k = beg, sum = 0.; k <= end; ++k) {
135                         int u, v11, v01, v10;
136                         double e;
137                         e = (ref[k] > 3 || qyi > 3)? 1. : ref[k] == qyi? 1. - qli : qli * EM;
138                         set_u(u, bw, i, k); set_u(v11, bw, i-1, k-1); set_u(v10, bw, i-1, k); set_u(v01, bw, i, k-1);
139                         fi[u+0] = e * (m[0] * fi1[v11+0] + m[3] * fi1[v11+1] + m[6] * fi1[v11+2]);
140                         fi[u+1] = EI * (m[1] * fi1[v10+0] + m[4] * fi1[v10+1]);
141                         fi[u+2] = m[2] * fi[v01+0] + m[8] * fi[v01+2];
142                         sum += fi[u] + fi[u+1] + fi[u+2];
143 //                      fprintf(stderr, "F (%d,%d;%d): %lg,%lg,%lg\n", i, k, u, fi[u], fi[u+1], fi[u+2]); // DEBUG
144                 }
145                 // rescale
146                 s[i] = sum;
147                 set_u(_beg, bw, i, beg); set_u(_end, bw, i, end); _end += 2;
148                 for (k = _beg, sum = 1./sum; k <= _end; ++k) fi[k] *= sum;
149         }
150         { // f[l_query+1]
151                 double sum;
152                 for (k = 1, sum = 0.; k <= l_ref; ++k) {
153                         int u;
154                         set_u(u, bw, l_query, k);
155                         if (u < 3 || u >= bw2*3+3) continue;
156                     sum += f[l_query][u+0] * sM + f[l_query][u+1] * sI;
157                 }
158                 s[l_query+1] = sum; // the last scaling factor
159         }
160         { // compute likelihood
161                 double p = 1., Pr1 = 0.;
162                 for (i = 0; i <= l_query + 1; ++i) {
163                         p *= s[i];
164                         if (p < 1e-100) Pr += -4.343 * log(p), p = 1.;
165                 }
166                 Pr1 += -4.343 * log(p * l_ref * l_query);
167                 Pr = (int)(Pr1 + .499);
168                 if (!is_backward) { // skip backward and MAP
169                         for (i = 0; i <= l_query; ++i) free(f[i]);
170                         free(f); free(s); free(_qual);
171                         return Pr;
172                 }
173         }
174         /*** backward ***/
175         // b[l_query] (b[l_query+1][0]=1 and thus \tilde{b}[][]=1/s[l_query+1]; this is where s[l_query+1] comes from)
176         for (k = 1; k <= l_ref; ++k) {
177                 int u;
178                 double *bi = b[l_query];
179                 set_u(u, bw, l_query, k);
180                 if (u < 3 || u >= bw2*3+3) continue;
181                 bi[u+0] = sM / s[l_query] / s[l_query+1]; bi[u+1] = sI / s[l_query] / s[l_query+1];
182         }
183         // b[l_query-1..1]
184         for (i = l_query - 1; i >= 1; --i) {
185                 int beg = 1, end = l_ref, x, _beg, _end;
186                 double *bi = b[i], *bi1 = b[i+1], y = (i > 1), qli1 = qual[i+1];
187                 uint8_t qyi1 = query[i+1];
188                 x = i - bw; beg = beg > x? beg : x;
189                 x = i + bw; end = end < x? end : x;
190                 for (k = end; k >= beg; --k) {
191                         int u, v11, v01, v10;
192                         double e;
193                         set_u(u, bw, i, k); set_u(v11, bw, i+1, k+1); set_u(v10, bw, i+1, k); set_u(v01, bw, i, k+1);
194                         e = (k >= l_ref? 0 : (ref[k+1] > 3 || qyi1 > 3)? 1. : ref[k+1] == qyi1? 1. - qli1 : qli1 * EM) * bi1[v11];
195                         bi[u+0] = e * m[0] + EI * m[1] * bi1[v10+1] + m[2] * bi[v01+2]; // bi1[v11] has been foled into e.
196                         bi[u+1] = e * m[3] + EI * m[4] * bi1[v10+1];
197                         bi[u+2] = (e * m[6] + m[8] * bi[v01+2]) * y;
198 //                      fprintf(stderr, "B (%d,%d;%d): %lg,%lg,%lg\n", i, k, u, bi[u], bi[u+1], bi[u+2]); // DEBUG
199                 }
200                 // rescale
201                 set_u(_beg, bw, i, beg); set_u(_end, bw, i, end); _end += 2;
202                 for (k = _beg, y = 1./s[i]; k <= _end; ++k) bi[k] *= y;
203         }
204         { // b[0]
205                 int beg = 1, end = l_ref < bw + 1? l_ref : bw + 1;
206                 double sum = 0.;
207                 for (k = end; k >= beg; --k) {
208                         int u;
209                         double e = (ref[k] > 3 || query[1] > 3)? 1. : ref[k] == query[1]? 1. - qual[1] : qual[1] * EM;
210                         set_u(u, bw, 1, k);
211                         if (u < 3 || u >= bw2*3+3) continue;
212                     sum += e * b[1][u+0] * bM + EI * b[1][u+1] * bI;
213                 }
214                 set_u(k, bw, 0, 0);
215                 pb = b[0][k] = sum / s[0]; // if everything works as is expected, pb == 1.0
216         }
217         is_diff = fabs(pb - 1.) > 1e-7? 1 : 0;
218         /*** MAP ***/
219         for (i = 1; i <= l_query; ++i) {
220                 double sum = 0., *fi = f[i], *bi = b[i], max = 0.;
221                 int beg = 1, end = l_ref, x, max_k = -1;
222                 x = i - bw; beg = beg > x? beg : x;
223                 x = i + bw; end = end < x? end : x;
224                 for (k = beg; k <= end; ++k) {
225                         int u;
226                         double z;
227                         set_u(u, bw, i, k);
228                         z = fi[u+0] * bi[u+0]; if (z > max) max = z, max_k = (k-1)<<2 | 0; sum += z;
229                         z = fi[u+1] * bi[u+1]; if (z > max) max = z, max_k = (k-1)<<2 | 1; sum += z;
230                 }
231                 max /= sum; sum *= s[i]; // if everything works as is expected, sum == 1.0
232                 if (state) state[i-1] = max_k;
233                 if (q) k = (int)(-4.343 * log(1. - max) + .499), q[i-1] = k > 100? 99 : k;
234 #ifdef _MAIN
235                 fprintf(stderr, "(%.10lg,%.10lg) (%d,%d:%c,%c:%d) %lg\n", pb, sum, i-1, max_k>>2,
236                                 "ACGT"[query[i]], "ACGT"[ref[(max_k>>2)+1]], max_k&3, max); // DEBUG
237 #endif
238         }
239         /*** free ***/
240         for (i = 0; i <= l_query; ++i) {
241                 free(f[i]); free(b[i]);
242         }
243         free(f); free(b); free(s); free(_qual);
244         return Pr;
245 }
246
247 #ifdef _MAIN
248 #include <unistd.h>
249 int main(int argc, char *argv[])
250 {
251         uint8_t conv[256], *iqual, *ref, *query;
252         int c, l_ref, l_query, i, q = 30, b = 10, P;
253         while ((c = getopt(argc, argv, "b:q:")) >= 0) {
254                 switch (c) {
255                 case 'b': b = atoi(optarg); break;
256                 case 'q': q = atoi(optarg); break;
257                 }
258         }
259         if (optind + 2 > argc) {
260                 fprintf(stderr, "Usage: %s [-q %d] [-b %d] <ref> <query>\n", argv[0], q, b); // example: acttc attc
261                 return 1;
262         }
263         memset(conv, 4, 256);
264         conv['a'] = conv['A'] = 0; conv['c'] = conv['C'] = 1;
265         conv['g'] = conv['G'] = 2; conv['t'] = conv['T'] = 3;
266         ref = (uint8_t*)argv[optind]; query = (uint8_t*)argv[optind+1];
267         l_ref = strlen((char*)ref); l_query = strlen((char*)query);
268         for (i = 0; i < l_ref; ++i) ref[i] = conv[ref[i]];
269         for (i = 0; i < l_query; ++i) query[i] = conv[query[i]];
270         iqual = malloc(l_query);
271         memset(iqual, q, l_query);
272         kpa_par_def.bw = b;
273         P = kpa_glocal(ref, l_ref, query, l_query, iqual, &kpa_par_alt, 0, 0);
274         fprintf(stderr, "%d\n", P);
275         free(iqual);
276         return 0;
277 }
278 #endif