Imported Upstream version 0.7
[pysam.git] / tests / tabix_test.py
1 #!/usr/bin/env python
2 '''unit testing code for pysam.
3
4 Execute in the :file:`tests` directory as it requires the Makefile
5 and data files located there.
6 '''
7
8 import sys, os, shutil, gzip
9 import pysam
10 import unittest
11 import itertools
12 import subprocess
13 import glob
14 import re
15
16 IS_PYTHON3 = sys.version_info[0] >= 3
17
18 def myzip_open( infile, mode = "r" ):
19     '''open compressed file and decode.'''
20
21     def _convert(f):
22         for l in f:
23             yield l.decode("ascii")
24
25     if IS_PYTHON3:
26         if mode == "r":
27             return _convert(gzip.open(infile,"r"))
28     else:
29         return gzip.open( mode )
30
31 def loadAndConvert( infile ):
32     '''load and convert all fields to bytes'''
33     data = []
34     if infile.endswith(".gz"):
35         for line in gzip.open( infile ):
36             line = line.decode("ascii")
37             if line.startswith("#"): continue
38             d = line.strip().split("\t")
39             data.append( [x.encode("ascii") for x in d ] )
40     else:
41         with open(infile) as f:
42             for line in f:
43                 if line.startswith("#"): continue
44                 d = line.strip().split("\t")
45                 data.append( [x.encode("ascii") for x in d ] )
46
47     return data
48
49 def splitToBytes( s ):
50     '''split string and return list of bytes.'''
51     return [x.encode("ascii") for x in s.split("\t")]
52
53 def checkBinaryEqual( filename1, filename2 ):
54     '''return true if the two files are binary equal.'''
55     if os.path.getsize( filename1 ) !=  os.path.getsize( filename2 ):
56         return False
57
58     infile1 = open(filename1, "rb")
59     infile2 = open(filename2, "rb")
60
61     d1, d2 = infile1.read(), infile2.read()
62     found = False
63     for c1,c2 in zip( d1, d2 ):
64         if c1 != c2: break
65     else:
66         found = True
67
68     infile1.close()
69     infile2.close()
70     return found
71
72 class TestIndexing(unittest.TestCase):
73     filename = "example.gtf.gz" 
74     filename_idx = "example.gtf.gz.tbi" 
75
76     def setUp( self ):
77         
78         self.tmpfilename = "tmp_%i.gtf.gz" % id(self)
79         shutil.copyfile( self.filename, self.tmpfilename )
80
81     def testIndexPreset( self ):
82         '''test indexing via preset.'''
83
84         pysam.tabix_index( self.tmpfilename, preset = "gff" )
85         checkBinaryEqual( self.tmpfilename + ".tbi", self.filename_idx )
86
87     def tearDown( self ):
88         os.unlink( self.tmpfilename )
89         os.unlink( self.tmpfilename + ".tbi" )
90
91 class TestCompression(unittest.TestCase):
92     filename = "example.gtf.gz" 
93     filename_idx = "example.gtf.gz.tbi" 
94
95     def setUp( self ):
96         
97         self.tmpfilename = "tmp_%i.gtf" % id(self)
98         infile = gzip.open( self.filename, "rb")
99         outfile = open( self.tmpfilename, "wb" )
100         outfile.write( infile.read() )
101         outfile.close()
102         infile.close()
103
104     def testIndexPreset( self ):
105         '''test indexing via preset.'''
106         
107         pysam.tabix_index( self.tmpfilename, preset = "gff" )
108         checkBinaryEqual( self.tmpfilename + ".gz", self.filename )
109         checkBinaryEqual( self.tmpfilename + ".gz.tbi", self.filename_idx )
110
111     def testCompression( self ):
112         '''see also issue 106'''
113         pysam.tabix_compress( self.tmpfilename, self.tmpfilename + ".gz" )
114         checkBinaryEqual( self.tmpfilename, self.tmpfilename + ".gz" )
115         
116     def tearDown( self ):
117         os.unlink( self.tmpfilename + ".gz" )
118         if os.path.exists( self.tmpfilename + ".gz.tbi" ):
119             os.unlink( self.tmpfilename + ".gz.tbi" )
120
121 class TestIteration( unittest.TestCase ):
122
123     filename = "example.gtf.gz" 
124
125     def setUp( self ):
126
127         self.tabix = pysam.Tabixfile( self.filename )
128         lines = []
129         inf = gzip.open( self.filename, "rb")
130         for line in inf:
131             line = line.decode('ascii')
132             if line.startswith("#"): continue
133             lines.append( line )
134         inf.close()
135         # creates index of contig, start, end, adds content without newline.
136         self.compare = [ 
137             (x[0][0], int(x[0][3]), int(x[0][4]), x[1])
138             for x in [ (y.split("\t"), y[:-1]) for y in lines ] ]
139                          
140     def getSubset( self, contig = None, start = None, end = None):
141         
142         if contig == None:
143             # all lines
144             subset = [ x[3] for x in self.compare ]
145         else:
146             if start != None and end == None:
147                 # until end of contig
148                 subset = [ x[3] for x in self.compare if x[0] == contig and x[2] > start ]
149             elif start == None and end != None:
150                 # from start of contig
151                 subset = [ x[3] for x in self.compare if x[0] == contig and x[1] <= end ]
152             elif start == None and end == None:
153                 subset = [ x[3] for x in self.compare if x[0] == contig ]
154             else:
155                 # all within interval
156                 subset = [ x[3] for x in self.compare if x[0] == contig and \
157                                min( x[2], end) - max(x[1], start) > 0 ]
158             
159         return subset
160
161     def checkPairwise( self, result, ref ):
162         '''check pairwise results.
163         '''
164         result.sort()
165         ref.sort()
166
167         a = set(result)
168         b = set(ref)
169
170         self.assertEqual( len(result), len(ref),
171                           "unexpected number of results: result=%i, expected ref=%i, differences are %s: %s" \
172                               % (len(result), len(ref),
173                                  a.difference(b), 
174                                  b.difference(a) ))
175
176         for x, d in enumerate( list(zip( result, ref ))):
177             self.assertEqual( d[0], d[1],
178                               "unexpected results in pair %i:\n'%s', expected\n'%s'" % \
179                                   (x, 
180                                    d[0], 
181                                    d[1]) )
182
183
184     def testAll( self ):
185         result = list(self.tabix.fetch())
186         ref = self.getSubset( )
187         self.checkPairwise( result, ref )
188
189     def testPerContig( self ):
190         for contig in ("chr1", "chr2", "chr1", "chr2" ):
191             result = list(self.tabix.fetch( contig ))
192             ref = self.getSubset( contig )
193             self.checkPairwise( result, ref )
194             
195     def testPerContigToEnd( self ):
196         
197         end = None
198         for contig in ("chr1", "chr2", "chr1", "chr2" ):
199             for start in range( 0, 200000, 1000):
200                 result = list(self.tabix.fetch( contig, start, end ))
201                 ref = self.getSubset( contig, start, end )
202                 self.checkPairwise( result, ref )
203
204     def testPerContigFromStart( self ):
205         
206         start = None
207         for contig in ("chr1", "chr2", "chr1", "chr2" ):
208             for end in range( 0, 200000, 1000):
209                 result = list(self.tabix.fetch( contig, start, end ))
210                 ref = self.getSubset( contig, start, end )
211                 self.checkPairwise( result, ref )
212
213     def testPerContig( self ):
214         
215         start, end  = None, None
216         for contig in ("chr1", "chr2", "chr1", "chr2" ):
217             result = list(self.tabix.fetch( contig, start, end ))
218             ref = self.getSubset( contig, start, end )
219             self.checkPairwise( result, ref )
220                 
221     def testPerInterval( self ):
222         
223         start, end  = None, None
224         for contig in ("chr1", "chr2", "chr1", "chr2" ):
225             for start in range( 0, 200000, 2000):
226                 for end in range( start, start + 2000, 500):
227                     result = list(self.tabix.fetch( contig, start, end ))
228                     ref = self.getSubset( contig, start, end )
229                     self.checkPairwise( result, ref )
230                 
231
232     def testInvalidIntervals( self ):
233         
234         self.assertRaises( ValueError, self.tabix.fetch, "chr1", 0, -10)
235         self.assertRaises( ValueError, self.tabix.fetch, "chr1", -10, 200)
236         self.assertRaises( ValueError, self.tabix.fetch, "chr1", 200, 0)
237         self.assertRaises( ValueError, self.tabix.fetch, "chr1", -10, -20)
238         self.assertRaises( ValueError, self.tabix.fetch, "chrUn" )
239
240     def testGetContigs( self ):
241         self.assertEqual( sorted(self.tabix.contigs), [b"chr1", b"chr2"] )
242         # check that contigs is read-only
243         self.assertRaises( AttributeError, setattr, self.tabix, "contigs", ["chr1", "chr2"] )
244
245     def testHeader( self ):
246         ref = []
247         inf = gzip.open( self.filename )
248         for x in inf:
249             x = x.decode("ascii")
250             if not x.startswith("#"): break
251             ref.append( x[:-1].encode('ascii') )
252         inf.close()
253
254         header = list( self.tabix.header )
255         self.assertEqual( ref, header )
256
257     def testReopening( self ):
258         '''test repeated opening of the same file.'''
259         def func1():
260             # opens any tabix file
261             inf = pysam.Tabixfile(self.filename)
262             return
263
264         for i in range(10000):
265             func1()
266
267 class TestParser( unittest.TestCase ):
268
269     filename = "example.gtf.gz" 
270
271     def setUp( self ):
272
273         self.tabix = pysam.Tabixfile( self.filename )
274         self.compare = loadAndConvert( self.filename )
275
276     def testRead( self ):
277
278         for x, r in enumerate(self.tabix.fetch( parser = pysam.asTuple() )):
279             self.assertEqual( self.compare[x], list(r) )
280             self.assertEqual( len(self.compare[x]), len(r) )
281
282             # test indexing
283             for c in range(0,len(r)):
284                 self.assertEqual( self.compare[x][c], r[c] )
285
286             # test slicing access
287             for c in range(0, len(r)-1):
288                 for cc in range(c+1, len(r)):
289                     self.assertEqual( self.compare[x][c:cc],
290                                       r[c:cc] )
291
292     def testWrite( self ):
293         
294         for x, r in enumerate(self.tabix.fetch( parser = pysam.asTuple() )):
295             self.assertEqual( self.compare[x], list(r) )
296             c = list(r)
297             for y in range(len(r)):
298                 r[y] = "test_%05i" % y
299                 c[y] = "test_%05i" % y
300             self.assertEqual( [x.encode("ascii") for x in c], list(r) )
301             self.assertEqual( "\t".join( c ), str(r) )
302             # check second assignment
303             for y in range(len(r)):
304                 r[y] = "test_%05i" % y
305             self.assertEqual( [x.encode("ascii") for x in c], list(r) )
306             self.assertEqual( "\t".join( c ), str(r) )
307
308     def testUnset( self ):
309         for x, r in enumerate(self.tabix.fetch( parser = pysam.asTuple() )):
310             self.assertEqual( self.compare[x], list(r) )
311             c = list(r)
312             e = [ x.decode('ascii') for x in r ]
313             for y in range(len(r)):
314                 r[y] = None
315                 c[y] = None
316                 e[y] = ""
317                 self.assertEqual( c, list(r) )
318                 self.assertEqual( "\t".join(e), str(r) )
319
320     def testIteratorCompressed( self ):
321         '''test iteration from compressed file.'''
322         with gzip.open( self.filename ) as infile:
323             for x, r in enumerate(pysam.tabix_iterator( infile, pysam.asTuple() )):
324                 self.assertEqual( self.compare[x], list(r) )
325                 self.assertEqual( len(self.compare[x]), len(r) )
326
327                 # test indexing
328                 for c in range(0,len(r)):
329                     self.assertEqual( self.compare[x][c], r[c] )
330
331                 # test slicing access
332                 for c in range(0, len(r)-1):
333                     for cc in range(c+1, len(r)):
334                         self.assertEqual( self.compare[x][c:cc],
335                                           r[c:cc] )
336
337     def testIteratorUncompressed( self ):
338         '''test iteration from uncompressed file.'''
339         tmpfilename = 'tmp_testIteratorUncompressed'
340         infile = gzip.open( self.filename, "rb")
341         outfile = open( tmpfilename, "wb" )
342         outfile.write( infile.read() )
343         outfile.close()
344         infile.close()
345
346         with open( tmpfilename ) as infile:
347             for x, r in enumerate(pysam.tabix_iterator( infile, pysam.asTuple() )):
348                 self.assertEqual( self.compare[x], list(r) )
349                 self.assertEqual( len(self.compare[x]), len(r) )
350
351                 # test indexing
352                 for c in range(0,len(r)):
353                     self.assertEqual( self.compare[x][c], r[c] )
354
355                 # test slicing access
356                 for c in range(0, len(r)-1):
357                     for cc in range(c+1, len(r)):
358                         self.assertEqual( self.compare[x][c:cc],
359                                           r[c:cc] )
360
361         os.unlink( tmpfilename )
362
363 class TestGTF( TestParser ):
364
365     def testRead( self ):
366
367         for x, r in enumerate(self.tabix.fetch( parser = pysam.asGTF() )):
368             c = self.compare[x]
369             self.assertEqual( len(c), len(r) )
370             self.assertEqual( list(c), list(r) )
371             self.assertEqual( c, splitToBytes( str(r) ) )
372             self.assertTrue( r.gene_id.startswith("ENSG") )
373             if r.feature != b'gene':
374                 self.assertTrue( r.transcript_id.startswith("ENST") )
375             self.assertEqual( c[0], r.contig )
376
377 class TestBed( unittest.TestCase ):
378     filename = "example.bed.gz"
379
380     def setUp( self ):
381
382         self.tabix = pysam.Tabixfile( self.filename)
383         self.compare = loadAndConvert( self.filename )
384
385     def testRead( self ):
386
387         for x, r in enumerate(self.tabix.fetch( parser = pysam.asBed() )):
388             c = self.compare[x]
389             self.assertEqual( len(c), len(r) )
390             self.assertEqual( c, splitToBytes( str(r) ) )
391             self.assertEqual( list(c), list(r) )
392             self.assertEqual( c[0], r.contig)
393             self.assertEqual( int(c[1]), r.start)
394             self.assertEqual( int(c[2]), r.end)
395
396     def testWrite( self ):
397
398         for x, r in enumerate(self.tabix.fetch( parser = pysam.asBed() )):
399             c = self.compare[x]
400             self.assertEqual( c, splitToBytes(str(r) ))
401             self.assertEqual( list(c), list(r) )
402
403             r.contig = "test"
404             self.assertEqual( b"test", r.contig)
405             self.assertEqual( b"test", r[0])
406
407             r.start += 1
408             self.assertEqual( int(c[1]) + 1, r.start )
409             self.assertEqual( str(int(c[1]) + 1), r[1].decode("ascii" ))
410
411             r.end += 1
412             self.assertEqual( int(c[2]) + 1, r.end )
413             self.assertEqual( str(int(c[2]) + 1), r[2].decode("ascii") )
414
415 class TestVCFFromTabix( unittest.TestCase ):
416
417     filename = "example.vcf40"
418
419     columns = ("contig", "pos", "id", 
420                "ref", "alt", "qual", 
421                "filter", "info", "format" )
422
423     def setUp( self ):
424         
425         self.tmpfilename = "tmp_%s.vcf" % id(self)
426         shutil.copyfile( self.filename, self.tmpfilename )
427         pysam.tabix_index( self.tmpfilename, preset = "vcf" )
428
429         self.tabix = pysam.Tabixfile( self.tmpfilename + ".gz" )
430         self.compare = loadAndConvert( self.filename )
431
432     def tearDown( self ):
433
434         os.unlink( self.tmpfilename + ".gz" )
435         os.unlink( self.tmpfilename + ".gz.tbi" )
436
437     def testRead( self ):
438         
439         ncolumns = len(self.columns) 
440
441         for x, r in enumerate(self.tabix.fetch( parser = pysam.asVCF() )):
442             c = self.compare[x]
443
444             for y, field in enumerate( self.columns ):
445                 # it is ok to have a missing format column
446                 if y == 8 and y == len(c): continue
447
448                 if field == "pos":
449                     self.assertEqual( int(c[y]) - 1, getattr( r, field ) )
450                     self.assertEqual( int(c[y]) - 1, r.pos )
451                 else:
452                     self.assertEqual( c[y], getattr( r, field ), 
453                                       "mismatch in field %s: %s != %s" %\
454                                           ( field,c[y], getattr( r, field ) ) )
455             if len(c) == 8:
456                 self.assertEqual( 0, len(r) )
457             else:
458                 self.assertEqual( len(c), len( r ) + ncolumns )
459             
460             for y in range(len(c) - ncolumns):
461                 self.assertEqual( c[ncolumns+y], r[y] )
462                 
463     def testWrite( self ):
464
465         ncolumns = len(self.columns) 
466
467         for x, r in enumerate(self.tabix.fetch( parser = pysam.asVCF() )):
468             c = self.compare[x]
469             # check unmodified string
470             cmp_string = str(r)
471             ref_string = "\t".join( [x.decode() for x in c] )
472
473             self.assertEqual( ref_string, cmp_string )
474             
475             # set fields and compare field-wise
476             for y, field in enumerate( self.columns ):
477                 # it is ok to have a missing format column
478                 if y == 8 and y == len(c): continue
479                 if field == "pos":
480                     rpos = getattr( r, field )
481                     self.assertEqual( int(c[y]) - 1, rpos )
482                     self.assertEqual( int(c[y]) - 1, r.pos )
483                     # increment pos by 1
484                     setattr( r, field, rpos + 1 )
485                     self.assertEqual( getattr( r, field ), rpos + 1 )
486                     c[y] = str(int(c[y]) + 1 ) 
487                 else:
488                     setattr( r, field, "test_%i" % y)
489                     c[y] = ("test_%i" % y).encode('ascii')
490                     self.assertEqual( c[y], getattr( r, field ), 
491                                       "mismatch in field %s: %s != %s" %\
492                                           ( field,c[y], getattr( r, field ) ) )
493
494             if len(c) == 8:
495                 self.assertEqual( 0, len(r) )
496             else:
497                 self.assertEqual( len(c), len( r ) + ncolumns )
498             
499             for y in range(len(c) - ncolumns):
500                 c[ncolumns+y] = ("test_%i" % y).encode('ascii')
501                 r[y] = ("test_%i" % y).encode('ascii')
502                 self.assertEqual( c[ncolumns+y], r[y] )
503
504
505 class TestVCFFromVCF( unittest.TestCase ):
506
507     filename = "example.vcf40"
508
509     columns = ("contig", "pos", "id", 
510                "ref", "alt", "qual", 
511                "filter", "info", "format" )
512
513     # tests failing while parsing
514     fail_on_parsing = ( (5, "Flag fields should not have a value"),
515                        (9, "aouao" ),
516                        (13, "aoeu" ),
517                        (18, "Error BAD_NUMBER_OF_PARAMETERS" ),
518                        (24, "Error HEADING_NOT_SEPARATED_BY_TABS" ) )
519
520     # tests failing on opening
521     fail_on_opening = ( (24, "Error HEADING_NOT_SEPARATED_BY_TABS" ),
522                      )
523
524     def setUp( self ):
525         
526         self.vcf = pysam.VCF()
527         self.compare = loadAndConvert( self.filename )
528
529     def testParsing( self ):
530
531         fn = os.path.basename( self.filename )
532         with open(self.filename) as f:
533
534             for x, msg in self.fail_on_opening:
535                 if "%i.vcf" % x == fn:
536                     self.assertRaises( ValueError, self.vcf.parse, f )
537                     return
538             else:
539                 iter = self.vcf.parse(f)
540
541             for x, msg in self.fail_on_parsing:
542                 if "%i.vcf" % x == fn:
543                     self.assertRaises( ValueError, list, iter )
544                     break
545                     # python 2.7
546                     # self.assertRaisesRegexp( ValueError, re.compile(msg), self.vcf.parse, f )
547             else:
548                 for ln in iter:
549                     pass
550
551 ############################################################################                   
552 # create a test class for each example vcf file.
553 # Two samples are created - 
554 # 1. Testing pysam/tabix access
555 # 2. Testing the VCF class
556 vcf_files = glob.glob( "vcf-examples/*.vcf" )
557
558 for vcf_file in vcf_files:
559     n = "VCFFromTabixTest_%s" % os.path.basename( vcf_file[:-4] )
560     globals()[n] = type( n, (TestVCFFromTabix,), dict( filename=vcf_file,) )
561     n = "VCFFromVCFTest_%s" % os.path.basename( vcf_file[:-4] )
562     globals()[n] = type( n, (TestVCFFromVCF,), dict( filename=vcf_file,) )
563
564 ############################################################################                   
565 class TestRemoteFileHTTP( unittest.TestCase):
566
567     url = "http://genserv.anat.ox.ac.uk/downloads/pysam/test/example.gtf.gz"
568     region = "chr1:1-1000"
569     local = "example.gtf.gz"
570
571     def testFetchAll( self ):
572         remote_file = pysam.Tabixfile(self.url, "r")  
573         remote_result = list(remote_file.fetch())
574         local_file = pysam.Tabixfile(self.local, "r")  
575         local_result = list(local_file.fetch())
576
577         self.assertEqual( len(remote_result), len(local_result) )
578         for x, y in zip(remote_result, local_result):
579             self.assertEqual( x, y )
580
581
582 if __name__ == "__main__":
583
584     unittest.main()
585
586