dc84d4b02d836f0a07a0736588dd1142485d10cf
[samtools.git] / bcftools / ld.c
1 #include <stdlib.h>
2 #include <string.h>
3 #include <math.h>
4 #include "bcf.h"
5
6 static double g_q2p[256];
7
8 #define LD_ITER_MAX 50
9 #define LD_ITER_EPS 1e-4
10
11 #define _G1(h, k) ((h>>1&1) + (k>>1&1))
12 #define _G2(h, k) ((h&1) + (k&1))
13
14 // 0: the previous site; 1: the current site
15 static int freq_iter(int n, double *pdg[2], double f[4])
16 {
17         double ff[4];
18         int i, k, h;
19         memset(ff, 0, 4 * sizeof(double));
20         for (i = 0; i < n; ++i) {
21                 double *p[2], sum, tmp;
22                 p[0] = pdg[0] + i * 3; p[1] = pdg[1] + i * 3;
23                 for (k = 0, sum = 0.; k < 4; ++k)
24                         for (h = 0; h < 4; ++h)
25                                 sum += f[k] * f[h] * p[0][_G1(k,h)] * p[1][_G2(k,h)];
26                 for (k = 0; k < 4; ++k) {
27                         tmp = f[0] * (p[0][_G1(0,k)] * p[1][_G2(0,k)] + p[0][_G1(k,0)] * p[1][_G2(k,0)])
28                                 + f[1] * (p[0][_G1(1,k)] * p[1][_G2(1,k)] + p[0][_G1(k,1)] * p[1][_G2(k,1)])
29                                 + f[2] * (p[0][_G1(2,k)] * p[1][_G2(2,k)] + p[0][_G1(k,2)] * p[1][_G2(k,2)])
30                                 + f[3] * (p[0][_G1(3,k)] * p[1][_G2(3,k)] + p[0][_G1(k,3)] * p[1][_G2(k,3)]);
31                         ff[k] += f[k] * tmp / sum;
32                 }
33         }
34         for (k = 0; k < 4; ++k) f[k] = ff[k] / (2 * n);
35         return 0;
36 }
37
38 double bcf_ld_freq(const bcf1_t *b0, const bcf1_t *b1, double f[4])
39 {
40         const bcf1_t *b[2];
41         uint8_t *PL[2];
42         int i, j, PL_len[2], n_smpl;
43         double *pdg[2], flast[4], r;
44         // initialize g_q2p if necessary
45         if (g_q2p[0] == 0.)
46                 for (i = 0; i < 256; ++i)
47                         g_q2p[i] = pow(10., -i / 10.);
48         // initialize others
49         if (b0->n_smpl != b1->n_smpl) return -1; // different number of samples
50         n_smpl = b0->n_smpl;
51         b[0] = b0; b[1] = b1;
52         f[0] = f[1] = f[2] = f[3] = -1.;
53         if (b[0]->n_alleles < 2 || b[1]->n_alleles < 2) return -1; // one allele only
54         // set PL and PL_len
55         for (j = 0; j < 2; ++j) {
56                 const bcf1_t *bj = b[j];
57                 for (i = 0; i < bj->n_gi; ++i) {
58                         if (bj->gi[i].fmt == bcf_str2int("PL", 2)) {
59                                 PL[j] = (uint8_t*)bj->gi[i].data;
60                                 PL_len[j] = bj->gi[i].len;
61                                 break;
62                         }
63                 }
64                 if (i == bj->n_gi) return -1; // no PL
65         }
66         // fill pdg[2]
67         pdg[0] = malloc(3 * n_smpl * sizeof(double));
68         pdg[1] = malloc(3 * n_smpl * sizeof(double));
69         for (j = 0; j < 2; ++j) {
70                 for (i = 0; i < n_smpl; ++i) {
71                         const uint8_t *pi = PL[j] + i * PL_len[j];
72                         double *p = pdg[j] + i * 3;
73                         p[0] = g_q2p[pi[b[j]->n_alleles]]; p[1] = g_q2p[pi[1]]; p[2] = g_q2p[pi[0]];
74                 }
75         }
76         // iteration
77         f[0] = f[1] = f[2] = f[3] = 0.25; // this is a really bad guess...
78         for (j = 0; j < LD_ITER_MAX; ++j) {
79                 double eps = 0;
80                 memcpy(flast, f, 4 * sizeof(double));
81                 freq_iter(n_smpl, pdg, f);
82                 for (i = 0; i < 4; ++i) {
83                         double x = fabs(f[i] - flast[i]);
84                         if (x > eps) eps = x;
85                 }
86                 if (eps < LD_ITER_EPS) break;
87         }
88         // free
89         free(pdg[0]); free(pdg[1]);
90         { // calculate r^2
91                 double p[2], q[2], D;
92                 p[0] = f[0] + f[1]; q[0] = 1 - p[0];
93                 p[1] = f[0] + f[2]; q[1] = 1 - p[1];
94                 D = f[0] * f[3] - f[1] * f[2];
95                 r = sqrt(D * D / (p[0] * p[1] * q[0] * q[1]));
96                 // fprintf(stderr, "R(%lf,%lf,%lf,%lf)=%lf\n", f[0], f[1], f[2], f[3], r2);
97                 if (isnan(r)) r = -1.;
98         }
99         return r;
100 }