Imported Upstream version 0.1.15
[samtools.git] / bcftools / prob1.c
index 503a998457469cd2081ea3bf90b6aea3a96fea02..a024d041ea80baeb3e0c75427c148aae0db728d0 100644 (file)
@@ -10,7 +10,7 @@
 KSTREAM_INIT(gzFile, gzread, 16384)
 
 #define MC_MAX_EM_ITER 16
-#define MC_EM_EPS 1e-4
+#define MC_EM_EPS 1e-5
 #define MC_DEF_INDEL 0.15
 
 unsigned char seq_nt4_table[256] = {
@@ -209,21 +209,39 @@ static int cal_pdg(const bcf1_t *b, bcf_p1aux_t *ma)
        return i;
 }
 // f0 is the reference allele frequency
-static double mc_freq_iter(double f0, const bcf_p1aux_t *ma)
+static double mc_freq_iter(double f0, const bcf_p1aux_t *ma, int beg, int end)
 {
        double f, f3[3];
        int i;
        f3[0] = (1.-f0)*(1.-f0); f3[1] = 2.*f0*(1.-f0); f3[2] = f0*f0;
-       for (i = 0, f = 0.; i < ma->n; ++i) {
+       for (i = beg, f = 0.; i < end; ++i) {
                double *pdg;
                pdg = ma->pdg + i * 3;
                f += (pdg[1] * f3[1] + 2. * pdg[2] * f3[2])
                        / (pdg[0] * f3[0] + pdg[1] * f3[1] + pdg[2] * f3[2]);
        }
-       f /= ma->n * 2.;
+       f /= (end - beg) * 2.;
        return f;
 }
 
+static double mc_gtfreq_iter(double g[3], const bcf_p1aux_t *ma, int beg, int end)
+{
+       double err, gg[3];
+       int i;
+       gg[0] = gg[1] = gg[2] = 0.;
+       for (i = beg; i < end; ++i) {
+               double *pdg, sum, tmp[3];
+               pdg = ma->pdg + i * 3;
+               tmp[0] = pdg[0] * g[0]; tmp[1] = pdg[1] * g[1]; tmp[2] = pdg[2] * g[2];
+               sum = (tmp[0] + tmp[1] + tmp[2]) * (end - beg);
+               gg[0] += tmp[0] / sum; gg[1] += tmp[1] / sum; gg[2] += tmp[2] / sum;
+       }
+       err = fabs(gg[0] - g[0]) > fabs(gg[1] - g[1])? fabs(gg[0] - g[0]) : fabs(gg[1] - g[1]);
+       err = err > fabs(gg[2] - g[2])? err : fabs(gg[2] - g[2]);
+       g[0] = gg[0]; g[1] = gg[1]; g[2] = gg[2];
+       return err;
+}
+
 int bcf_p1_call_gt(const bcf_p1aux_t *ma, double f0, int k)
 {
        double sum, g[3];
@@ -448,6 +466,8 @@ static double contrast2(bcf_p1aux_t *p1, double ret[3])
                        for (k1 = 0, z = 0.; k1 <= 2*n1; ++k1)
                                for (k2 = 0; k2 <= 2*n2; ++k2)
                                        if ((y = contrast2_aux(p1, sum, n1, n2, k1, k2, ret)) >= 0) z += y;
+                       if (ret[0] + ret[1] + ret[2] < 0.99) // It seems that this may be caused by floating point errors. I do not really understand why...
+                               z = 1.0, ret[0] = ret[1] = ret[2] = 1./3;
                }
                return (double)z;
        }
@@ -516,10 +536,27 @@ int bcf_p1_cal(const bcf1_t *b, bcf_p1aux_t *ma, bcf_p1rst_t *rst)
        { // calculate f_em
                double flast = rst->f_flat;
                for (i = 0; i < MC_MAX_EM_ITER; ++i) {
-                       rst->f_em = mc_freq_iter(flast, ma);
+                       rst->f_em = mc_freq_iter(flast, ma, 0, ma->n);
                        if (fabs(rst->f_em - flast) < MC_EM_EPS) break;
                        flast = rst->f_em;
                }
+               if (ma->n1 > 0 && ma->n1 < ma->n) {
+                       for (k = 0; k < 2; ++k) {
+                               flast = rst->f_em;
+                               for (i = 0; i < MC_MAX_EM_ITER; ++i) {
+                                       rst->f_em2[k] = k? mc_freq_iter(flast, ma, ma->n1, ma->n) : mc_freq_iter(flast, ma, 0, ma->n1);
+                                       if (fabs(rst->f_em2[k] - flast) < MC_EM_EPS) break;
+                                       flast = rst->f_em2[k];
+                               }
+                       }
+               }
+       }
+       { // compute g[3]
+               rst->g[0] = (1. - rst->f_em) * (1. - rst->f_em);
+               rst->g[1] = 2. * rst->f_em * (1. - rst->f_em);
+               rst->g[2] = rst->f_em * rst->f_em;
+               for (i = 0; i < MC_MAX_EM_ITER; ++i)
+                       if (mc_gtfreq_iter(rst->g, ma, 0, ma->n) < MC_EM_EPS) break;
        }
        { // estimate equal-tail credible interval (95% level)
                int l, h;
@@ -534,7 +571,6 @@ int bcf_p1_cal(const bcf1_t *b, bcf_p1aux_t *ma, bcf_p1rst_t *rst)
                h = i;
                rst->cil = (double)(ma->M - h) / ma->M; rst->cih = (double)(ma->M - l) / ma->M;
        }
-       rst->g[0] = rst->g[1] = rst->g[2] = -1.;
        rst->cmp[0] = rst->cmp[1] = rst->cmp[2] = rst->p_chi2 = -1.0;
        if (rst->p_var > 0.1) // skip contrast2() if the locus is a strong non-variant
                rst->p_chi2 = contrast2(ma, rst->cmp);