1 """Partially through ENCODE3 the DCC switched to needing to upload via AWS
2 """

import logging
import json
import os
from pprint import pformat, pprint
import string
import subprocess
import time
import re

import jsonschema
import RDF
from requests.exceptions import HTTPError

from htsworkflow.submission.submission import Submission
from .encoded import ENCODED
from htsworkflow.util.rdfhelp import (
    fromTypedNode,
    geoSoftNS,
    submissionOntology,
)
from htsworkflow.util.ucsc import bigWigInfo

from django.conf import settings
from django.template import Context, loader

LOGGER = logging.getLogger(__name__)

class AWSSubmission(Submission):
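    """Submission that posts file metadata to the ENCODE DCC and copies the data files to the S3 upload URL it returns."""
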
    def __init__(self, name, model, encode_host, lims_host):
        """Create an AWS based submission

        :Parameters:
          - `name`: Name of submission
          - `model`: librdf model reference
          - `encode_host`: hostname of the ENCODE DCC server
          - `lims_host`: hostname for library pages
        """
        super(AWSSubmission, self).__init__(name, model, lims_host)
        self.encode = ENCODED(encode_host)
        self.encode.load_netrc()

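        # caches of replicate and file objects already fetched from the DCC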
        self._replicates = {}
        self._files = {}

    def check_upload(self, results_map):
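        """Verify that files recorded in local .upload files are present at the DCC.

        Phase one reads the .upload metadata written by upload() and fetches
        the corresponding replicate and experiment file records from the DCC;
        phase two checks that each recorded submitted_file_name and md5sum
        appears among the DCC's files.
        """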
        tocheck = []
        # phase 1: read local upload records and fetch the matching DCC objects
        for an_analysis in self.analysis_nodes(results_map):
            for metadata in self.get_metadata(an_analysis):
                filename = self.make_upload_filename(metadata)
                if os.path.exists(filename):
                    with open(filename, 'rt') as instream:
                        uploaded = json.load(instream)
                    tocheck.append({
                        'submitted_file_name': uploaded['submitted_file_name'],
                        'md5sum': uploaded['md5sum']
                    })
                    self.update_replicates(uploaded)

        # phase 2: make sure each file we submitted is known to the DCC
        md5sums = set(self._files[f]['md5sum'] for f in self._files)
        submitted_file_names = set(
            self._files[f]['submitted_file_name'] for f in self._files
        )
        errors_detected = 0
        for row in tocheck:
            error = []
            if row['submitted_file_name'] not in submitted_file_names:
                error.append('!file_name')
            if row['md5sum'] not in md5sums:
                error.append('!md5sum')
            if error:
                print("{} failed {} checks".format(
                    row['submitted_file_name'],
                    ', '.join(error)
                ))
                errors_detected += 1

        if not errors_detected:
            print('No errors detected')

    def update_replicates(self, metadata):
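        """Cache the replicate object named in metadata and the files of its experiment."""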
        replicate_id = metadata['replicate']
        if replicate_id not in self._replicates:
            LOGGER.debug("Downloading %s", replicate_id)
            try:
                rep = self.encode.get_json(replicate_id)

                self._replicates[replicate_id] = rep
                for file_id in rep['experiment']['files']:
                    self.update_files(file_id)
            except HTTPError as err:
                if err.response.status_code == 404:
                    print('Unable to find {} {}'.format(
                        replicate_id,
                        metadata['submitted_file_name']))
                else:
                    raise

    def update_files(self, file_id):
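        """Cache the DCC file object for file_id; a 404 is reported but not fatal."""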
        if file_id not in self._files:
            LOGGER.debug("Downloading %s", file_id)
            try:
                file_object = self.encode.get_json(file_id)
                self._files[file_id] = file_object
            except HTTPError as err:
                if err.response.status_code == 404:
                    print('Unable to find {}'.format(file_id))
                else:
                    raise

    def upload(self, results_map, dry_run=False):
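        """Validate and post file metadata to the DCC, then copy each file to the S3 upload URL it returns.

        The posted metadata is also recorded in a .upload file next to the
        data file; files that already have a .upload record are skipped.
        With dry_run=True the validated metadata is only logged.
        """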
        for an_analysis in self.analysis_nodes(results_map):
            for metadata in self.get_metadata(an_analysis):
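                # give the record a temporary @type for validate(), then remove it before posting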
                metadata['@type'] = ['file']
                self.encode.validate(metadata)
                del metadata['@type']

                if dry_run:
                    LOGGER.info(json.dumps(metadata, indent=4, sort_keys=True))
                    continue

                upload = self.make_upload_filename(metadata)
                if not os.path.exists(upload):
                    with open(upload, 'w') as outstream:
                        json.dump(metadata, outstream, indent=4, sort_keys=True)
                    LOGGER.debug(json.dumps(metadata, indent=4, sort_keys=True))

                    response = self.encode.post_json('/file', metadata)
                    LOGGER.info(json.dumps(response, indent=4, sort_keys=True))

                    item = response['@graph'][0]
                    creds = item['upload_credentials']
                    run_aws_cp(metadata['submitted_file_name'], creds)
                else:
                    LOGGER.info('%s already uploaded',
                                metadata['submitted_file_name'])

    def get_metadata(self, analysis_node):
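        """Run the aws_metadata.sparql query for analysis_node and reshape the rows into DCC file records."""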
        # convert our model names to ENCODE project aliases
        platform_alias = {
            'Illumina HiSeq 2500': 'ENCODE:HiSeq2500'
        }
        query_template = loader.get_template('aws_metadata.sparql')

        context = Context({
            'submission': str(analysis_node.uri),
            'submissionSet': str(self.submissionSetNS[''].uri),
            })
        results = self.execute_query(query_template, context)
        LOGGER.info("scanned %s and found %s results",
                    str(analysis_node), len(results))

        # adjust the query results slightly to match what the DCC expects
        for row in results:
            if 'platform' in row:
                row['platform'] = platform_alias[row['platform']]
            if 'read_length' in row:
                row['read_length'] = int(row['read_length'])
            flowcell_details = {}
            for term in ['machine', 'flowcell', 'lane', 'barcode']:
                if term in row:
                    value = str(row[term])
                    flowcell_details[term] = value
                    del row[term]
            if flowcell_details:
                row['flowcell_details'] = [flowcell_details]

        return results

    def make_upload_filename(self, metadata):
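        """Return the name of the local .upload record kept next to the submitted file."""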
        return metadata['submitted_file_name'] + '.upload'

def run_aws_cp(pathname, creds):
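    """Copy pathname to S3 with the aws command line tool, using the temporary credentials returned by the DCC."""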
    env = os.environ.copy()
    env.update({
        'AWS_ACCESS_KEY_ID': creds['access_key'],
        'AWS_SECRET_ACCESS_KEY': creds['secret_key'],
        # newer aws cli releases read AWS_SESSION_TOKEN; older ones read AWS_SECURITY_TOKEN
        'AWS_SESSION_TOKEN': creds['session_token'],
        'AWS_SECURITY_TOKEN': creds['session_token'],
    })
    start = time.time()
    try:
        subprocess.check_call(['aws', 's3', 'cp', pathname, creds['upload_url']], env=env)
    except subprocess.CalledProcessError as e:
        LOGGER.error('Upload of %s failed with exit code %d', pathname, e.returncode)
        return
    else:
        end = time.time()
        LOGGER.info('Upload of %s finished in %.2f seconds',
                    pathname,
                    end - start)
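

# A minimal usage sketch (the host names, model, and results_map below are
# placeholders for illustration, not values taken from this repository):
#
#     submission = AWSSubmission('my-submission', model,
#                                encode_host='test.encodedcc.org',
#                                lims_host='lims.example.org')
#     submission.upload(results_map, dry_run=True)   # validate metadata without posting
#     submission.upload(results_map)                 # post metadata and copy files to S3
#     submission.check_upload(results_map)           # confirm the DCC has the files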