Added MACS source
[htsworkflow.git] / htswanalysis / MACS / lib / PeakModel.py
1 # Time-stamp: <2008-05-28 13:14:16 Tao Liu>
2
3 """Peak model: estimate the fragment length and scan window from paired +/- strand peaks.
4
5 Copyright (c) 2008 Yong Zhang, Tao Liu <taoliu@jimmy.harvard.edu>
6
7 This code is free software; you can redistribute it and/or modify it
8 under the terms of the Artistic License (see the file COPYING included
9 with the distribution).
10
11 @status:  experimental
12 @version: $Revision$
13 @author:  Yong Zhang, Tao Liu
14 @contact: taoliu@jimmy.harvard.edu
15 """
16 import sys, time, random
17 import logging
18
class PeakModel:
    """Peak Model class.

    Estimates the fragment length (and a scan window derived from it) from
    pairs of nearby peaks called independently on the + and - strands of the
    treatment tags.  Everything is computed eagerly: __init__ calls build().
    """
    def __init__ (self, treatment=None, gz = 0, fold=32, max_pairnum=500, bw=200, ts = 25, bg=0, max_dup_tags=1):
        """Store the parameters and build the model immediately.

        treatment    -- tag container; must provide ``total``,
                        ``get_chr_names()``, ``get_ranges_by_chr(chrom)`` and
                        ``get_comments_by_chr(chrom)``; the latter two return a
                        (plus strand, minus strand) pair of sequences.
        gz           -- divisor in the min_tags estimate (presumably the
                        effective genome size); must be non-zero.
        fold         -- multiplier applied to the expected per-window tag
                        count when computing min_tags.
        max_pairnum  -- maximum number of paired peaks used for the model.
        bw           -- band width; the peak window (peaksize) is 2*bw.
        ts           -- tag size, used as the projection width of one tag.
        bg           -- background redundant-tag rate.
        max_dup_tags -- tags beyond this count at one position are counted
                        as redundant.

        NOTE(review): the defaults treatment=None and gz=0 make build() fail;
        callers are expected to always pass real values.
        """
        self.treatment = treatment
        self.gz = gz
        self.fold = fold
        self.max_pairnum = max_pairnum
        self.bw = bw
        self.tsize = ts
        self.bg_redundant_rate = bg
        self.mr = max_dup_tags
        self.summary = ""
        self.plus_line = None
        self.minus_line = None
        self.shifted_line = None
        self.frag_length = None
        self.scan_window = None
        self.min_tags = None
        self.peaksize = None
        self.build()

    def build (self):
        """Build the model.

        Prepare self.peaksize and self.min_tags, call paired peaks on the
        treatment, keep at most self.max_pairnum of them (taking chromosomes
        in order) and then fill self.frag_length, self.scan_window,
        self.plus_line, self.minus_line and self.shifted_line.

        Terminates the process with sys.exit(1) when fewer than 100 paired
        peaks are found.
        """
        self.peaksize = 2*self.bw
        # minimum hits on a single strand for a window to count as a peak
        self.min_tags = float(self.treatment.total) * self.fold * self.peaksize / self.gz /2
        # use treatment data to build model
        paired_peakpos = self.__paired_peaks ()
        # select up to self.max_pairnum pairs of peaks to build the model
        num_paired_peakpos = 0
        num_paired_peakpos_remained = self.max_pairnum
        num_paired_peakpos_picked = 0
        # iterate over a snapshot of the keys: entries are removed in the loop
        for c in list(paired_peakpos.keys()):
            num_paired_peakpos += len(paired_peakpos[c])
            if num_paired_peakpos_remained == 0:
                paired_peakpos.pop(c)
            else:
                paired_peakpos[c] = paired_peakpos[c][:num_paired_peakpos_remained]
                num_paired_peakpos_remained -= len(paired_peakpos[c])
                num_paired_peakpos_picked += len(paired_peakpos[c])

        logging.info("#2 number of paired peaks: %d" % (num_paired_peakpos))
        if num_paired_peakpos < 100:
            logging.critical("Too few paired peaks (%d) so I can not build the model! Lower your MFOLD parameter may erase this error." % (num_paired_peakpos))
            logging.critical("Process is terminated!")
            sys.exit(1)
        elif num_paired_peakpos < self.max_pairnum:
            logging.warning("Fewer paired peaks (%d) than %d! Model may not be build well! Lower your MFOLD parameter may erase this warning. Now I will use %d pairs to build model!" % (num_paired_peakpos,self.max_pairnum,num_paired_peakpos_picked))
        logging.debug("Use %d pairs to build the model." % (num_paired_peakpos_picked))
        self.__paired_peak_model(paired_peakpos)

    def __str__ (self):
        """For debug...

        """
        return """
Summary of Peak Model:
  Baseline: %d
  Fragment size: %d
  Scan window size: %d
""" % (self.min_tags,self.frag_length,self.scan_window)

    def __paired_peak_model (self, paired_peakpos):
        """Use paired peak positions and treatment tag positions to build the model.

        Modify self.frag_length and self.scan_window, plus, for plotting,
        self.plus_line, self.minus_line and self.shifted_line (each of length
        1+2*self.peaksize).  Returns True.
        """
        window_size = 1+2*self.peaksize
        self.plus_line = [0]*window_size
        self.minus_line = [0]*window_size
        for chrom in paired_peakpos:
            paired_peakpos_chrom = paired_peakpos[chrom]
            tags = self.treatment.get_ranges_by_chr(chrom)
            # every paired peak contributes to both the plus and minus line
            self.plus_line = self.__model_add_line (paired_peakpos_chrom, tags[0], self.plus_line)
            self.minus_line = self.__model_add_line (paired_peakpos_chrom, tags[1], self.minus_line)

        # positions of the highest point of each line; on ties the median
        # position is used
        plus_max = max(self.plus_line)
        minus_max = max(self.minus_line)
        plus_tops = [i for i, v in enumerate(self.plus_line) if v == plus_max]
        minus_tops = [i for i, v in enumerate(self.minus_line) if v == minus_max]
        self.frag_length = minus_tops[len(minus_tops)//2] - plus_tops[len(plus_tops)//2] + 1
        shift_size = self.frag_length//2
        self.scan_window = max(self.frag_length,self.tsize)*2
        # shifted model: plus line moved right and minus line moved left by
        # shift_size, values shifted past either end are dropped.  Index
        # arithmetic (instead of slicing with -shift_size) keeps this valid
        # when shift_size is 0 or negative.
        self.shifted_line = [0]*window_size
        for i in range(window_size):
            j = i - shift_size
            if 0 <= j < window_size:
                self.shifted_line[i] += self.plus_line[j]
            j = i + shift_size
            if 0 <= j < window_size:
                self.shifted_line[i] += self.minus_line[j]
        return True

    def __overlapping_pairs (self, pos1, pos2):
        """Yield index pairs (i1, i2) with
        pos1[i1]-self.peaksize <= pos2[i2] <= pos1[i1]+self.peaksize.

        Both sequences are expected to be sorted ascending.  Pairs come out
        ordered by i1, then by i2.  The scan for each pos1 entry restarts at
        the first pos2 index that overlapped an earlier pos1 entry, and it
        continues with the next pos1 entry when pos2 runs out, so trailing
        pos2 entries are matched against every pos1 entry they overlap.
        """
        i2_max = len(pos2)
        i2_prev = 0             # first pos2 index overlapping an earlier pos1
        for i1, p1 in enumerate(pos1):
            found_overlap = False
            i2 = i2_prev
            while i2 < i2_max:
                p2 = pos2[i2]
                if p1-self.peaksize > p2:       # p2 left of the window
                    i2 += 1
                    continue
                if p1+self.peaksize < p2:       # p2 right of the window
                    break
                if not found_overlap:
                    found_overlap = True
                    i2_prev = i2                # only the first index is recorded
                yield i1, i2
                i2 += 1

    def __model_add_line (self, pos1, pos2, line):
        """Project each pos in pos2 which is included in
        [p1-self.peaksize, p1+self.peaksize] for some p1 in pos1 to the line.

        Each such (p1, p2) pair adds 1 to line[k] for every k in
        [p2-p1+peaksize-tsize//2, p2-p1+peaksize+tsize//2), clipped to the
        bounds of line.  pos2 must be sorted ascending; pos1 may be in any
        order (contributions are additive, so it is sorted here first).

        Returns line, which is modified in place.
        """
        sorted_pos1 = sorted(pos1)
        half_tag = self.tsize//2
        line_len = len(line)
        for i1, i2 in self.__overlapping_pairs(sorted_pos1, pos2):
            offset = pos2[i2] - sorted_pos1[i1] + self.peaksize
            for i in range(max(0, offset-half_tag), min(line_len, offset+half_tag)):
                line[i] += 1
        return line

    def __paired_peaks (self):
        """Call paired peaks from fwtrackI object.

        Return a dict mapping chromosome name to the list of paired peak
        center positions; chromosomes lacking peaks on either strand are
        left out.
        """
        # sorted() copies, so the treatment's own list is not reordered
        chrs = sorted(self.treatment.get_chr_names())
        paired_peaks_pos = {}
        for chrom in chrs:
            logging.debug("Chromosome: %s" % (chrom))
            tags = self.treatment.get_ranges_by_chr(chrom)
            counts = self.treatment.get_comments_by_chr(chrom)
            plus_peaksinfo = self.__naive_find_peaks (tags[0],counts[0])
            logging.debug("Number of unique tags on + strand: %d" % (len(tags[0])))
            logging.debug("Number of peaks in + strand: %d" % (len(plus_peaksinfo)))
            minus_peaksinfo = self.__naive_find_peaks (tags[1],counts[1])
            logging.debug("Number of unique tags on - strand: %d" % (len(tags[1])))
            logging.debug("Number of peaks in - strand: %d" % (len(minus_peaksinfo)))
            if not plus_peaksinfo or not minus_peaksinfo:
                logging.debug("Chrom %s is discarded!" % (chrom))
                continue
            paired_peaks_pos[chrom] = self.__find_pair_center (plus_peaksinfo, minus_peaksinfo)
            logging.debug("Number of paired peaks: %d" %(len(paired_peaks_pos[chrom])))
        return paired_peaks_pos

    def __find_pair_center (self, pluspeaks, minuspeaks):
        """Pair plus-strand and minus-strand peaks lying within self.peaksize
        of each other and return the list of pair center positions.

        pluspeaks and minuspeaks are (position, tag number) lists, expected
        sorted by position.  A pair is kept only when the two tag numbers are
        comparable: 0.5 < plus/minus < 2.
        """
        plus_pos = [pp for pp, _ in pluspeaks]
        minus_pos = [mp for mp, _ in minuspeaks]
        pair_centers = []
        for ip, im in self.__overlapping_pairs(plus_pos, minus_pos):
            (pp,pn) = pluspeaks[ip]
            (mp,mn) = minuspeaks[im]
            ratio = float(pn)/mn
            if 0.5 < ratio < 2:
                pair_centers.append((pp+mp)//2)
        return pair_centers

    def __window_peak (self, tag_list, redundant_tags):
        """Return (peak position, len(tag_list)) when the window qualifies as a
        peak, otherwise None.

        A window qualifies when its redundant rate is at most 2-fold the
        background redundant rate and it holds at least self.min_tags tag
        positions.
        """
        redundant_rate = float(redundant_tags)/(redundant_tags+len(tag_list))
        if redundant_rate <= 2*self.bg_redundant_rate and len(tag_list) >= self.min_tags:
            return (self.__naive_peak_pos(tag_list), len(tag_list))
        return None

    def __naive_find_peaks (self, taglist, countlist ):
        """Naively call peaks based on tags counting.

        taglist holds tag positions (sorted ascending) and countlist[i] is,
        presumably, the number of tags at taglist[i]; counts above self.mr
        are redundant.  Tags are grouped into consecutive windows no longer
        than self.peaksize; each window, including the last one, is called
        as a peak when its redundant rate is at most 2-fold the background
        (global) redundant rate and it has at least self.min_tags positions.

        Return peak positions and the tag number in peak region by a tuple
        list [(pos,num)]; fewer than 2 tags give an empty list.
        """
        peak_info = []
        if len(taglist)<2:
            return peak_info
        current_tag_list = []
        current_redundant_tags = 0
        for i, pos in enumerate(taglist):
            if current_tag_list and (pos-current_tag_list[0]+1) > self.peaksize:
                # pos opens a new window: evaluate the finished one
                peak = self.__window_peak(current_tag_list, current_redundant_tags)
                if peak is not None:
                    peak_info.append(peak)
                current_tag_list = []
                current_redundant_tags = 0
            current_tag_list.append(pos)
            current_redundant_tags += max(countlist[i]-self.mr,0)
        # no following tag closes the last window, so evaluate it explicitly
        peak = self.__window_peak(current_tag_list, current_redundant_tags)
        if peak is not None:
            peak_info.append(peak)
        return peak_info

    def __naive_peak_pos (self, pos_list ):
        """Naively calculate the position of peak.

        Each tag is projected onto a line as tsize//2 bases on either side
        (2*(tsize//2) bases in total).  Return the summit, i.e. the median of
        the highest positions of that line, in pos_list coordinates.
        """
        half_tag = self.tsize//2
        peak_length = pos_list[-1]+1-pos_list[0]+self.tsize
        start = pos_list[0]-half_tag
        horizon_line = [0]*peak_length     # the line for tags to be projected
        for pos in pos_list:
            for pp in range(pos-start-half_tag,pos-start+half_tag):
                horizon_line[pp] += 1
        top_p_num = max(horizon_line)
        # every highest position; more than one is possible
        top_pos = [pp for pp, v in enumerate(horizon_line) if v == top_p_num]
        return top_pos[len(top_pos)//2]+start