Refactor out code to submit a single file from a loop.
[htsworkflow.git] htsworkflow/submission/aws_submission.py
1 """Partially through ENCODE3 the DCC switched to needing to upload via AWS
2 """
3
4 import logging
5 import json
6 import os
import subprocess
import time

from requests.exceptions import HTTPError

from htsworkflow.submission.submission import Submission
from .encoded import ENCODED

from django.template import Context, loader

LOGGER = logging.getLogger(__name__)


class AWSSubmission(Submission):
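    """Submission that posts file metadata to the ENCODE DCC and uploads
    the data files with ``aws s3 cp``.
    """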
    def __init__(self, name, model, encode_host, lims_host):
        """Create an AWS-based submission.

        :Parameters:
          - `name`: Name of submission
          - `model`: librdf model reference
          - `encode_host`: hostname of the ENCODE DCC server
          - `lims_host`: hostname for library pages
        """
        super(AWSSubmission, self).__init__(name, model, lims_host)
        self.encode = ENCODED(encode_host)
        self.encode.load_netrc()

        self._replicates = {}
        self._files = {}

    def check_upload(self, results_map):
        """Check that previously uploaded files are known to the DCC."""
        tocheck = []
        # phase 1: read the local .upload marker files and cache DCC metadata
        for an_analysis in self.analysis_nodes(results_map):
            for metadata in self.get_metadata(an_analysis):
                filename = make_upload_filename(metadata)
                if os.path.exists(filename):
                    with open(filename, 'rt') as instream:
                        uploaded = json.load(instream)
                    tocheck.append({
                        'submitted_file_name': uploaded['submitted_file_name'],
                        'md5sum': uploaded['md5sum']
                    })
                    self.update_replicates(uploaded)

        # phase 2: make sure each submitted file is known to the DCC
        md5sums = set((self._files[f]['md5sum'] for f in self._files))
        submitted_file_names = set(
            (self._files[f]['submitted_file_name'] for f in self._files)
        )
        errors_detected = 0
        for row in tocheck:
            error = []
            if row['submitted_file_name'] not in submitted_file_names:
                error.append('!file_name')
            if row['md5sum'] not in md5sums:
                error.append('!md5sum')
            if error:
                print("{} failed {} checks".format(
                    row['submitted_file_name'],
                    ', '.join(error)
                ))
                errors_detected += 1

        if not errors_detected:
            print('No errors detected')

    def update_replicates(self, metadata):
        """Cache the replicate and its experiment's file records from the DCC."""
        replicate_id = metadata['replicate']
        if replicate_id not in self._replicates:
            LOGGER.debug("Downloading %s", replicate_id)
            try:
                rep = self.encode.get_json(replicate_id)

                self._replicates[replicate_id] = rep
                for file_id in rep['experiment']['files']:
                    self.update_files(file_id)
            except HTTPError as err:
                if err.response.status_code == 404:
                    print('Unable to find {} {}'.format(
                        replicate_id,
                        metadata['submitted_file_name'])
                    )
                else:
                    raise

    def update_files(self, file_id):
        """Cache a file record from the DCC."""
        if file_id not in self._files:
            LOGGER.debug("Downloading %s", file_id)
            try:
                file_object = self.encode.get_json(file_id)
                self._files[file_id] = file_object
            except HTTPError as err:
                if err.response.status_code == 404:
                    print('Unable to find {}'.format(file_id))
                else:
                    raise

    def upload(self, results_map, dry_run=False):
        """Post file metadata for each analysis and upload the data files."""
        for an_analysis in self.analysis_nodes(results_map):
            for metadata in self.get_metadata(an_analysis):
                upload_file(self.encode, metadata, dry_run)

    def get_metadata(self, analysis_node):
        """Run the aws_metadata query for an analysis node.

        Returns the query rows adjusted to match the ENCODE file schema.
        """
        # convert our model names to encode project aliases
        platform_alias = {
            'Illumina HiSeq 2500': 'ENCODE:HiSeq2500'
        }
        run_type_alias = {
            'Single': 'single-ended',
            'Paired': 'paired-ended',
        }
        query_template = loader.get_template('aws_metadata.sparql')

        context = Context({
            'submission': str(analysis_node.uri),
            'submissionSet': str(self.submissionSetNS[''].uri),
            })
        results = self.execute_query(query_template, context)
        LOGGER.info("scanned %s and found %s results",
                    str(analysis_node), len(results))

        # need to adjust the results of the query slightly.
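        # The query returns terms in our local vocabulary: the loop below
        # rewrites 'platform' and 'run_type' to their ENCODE aliases, casts
        # 'read_length' to an int, and folds machine/flowcell/lane/barcode
        # into a single 'flowcell_details' entry.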
        for row in results:
            if 'platform' in row:
                row['platform'] = platform_alias[row['platform']]
            if 'read_length' in row:
                row['read_length'] = int(row['read_length'])
            if 'run_type' in row:
                row['run_type'] = run_type_alias[str(row['run_type'])]
            flowcell_details = {}
            for term in ['machine', 'flowcell', 'lane', 'barcode']:
                if term in row:
                    value = str(row[term])
                    flowcell_details[term] = value
                    del row[term]
            if flowcell_details:
                row['flowcell_details'] = [flowcell_details]

        return results


def run_aws_cp(pathname, creds):
    """Copy pathname to the DCC-provided upload URL using ``aws s3 cp``."""
    env = os.environ.copy()
    env.update({
        'AWS_ACCESS_KEY_ID': creds['access_key'],
        'AWS_SECRET_ACCESS_KEY': creds['secret_key'],
        'AWS_SECURITY_TOKEN': creds['session_token'],
    })
    start = time.time()
    try:
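        # Equivalent to running, with the temporary credentials exported in
        # the environment:
        #   aws s3 cp <pathname> <upload_url>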
        subprocess.check_call(['aws', 's3', 'cp', pathname, creds['upload_url']], env=env)
    except subprocess.CalledProcessError as e:
        LOGGER.error('Upload of %s failed with exit code %d', pathname, e.returncode)
        return
    else:
        end = time.time()
        LOGGER.info('Upload of %s finished in %.2f seconds',
                    pathname,
                    end - start)


def upload_file(encode, metadata, dry_run=True):
170     """Upload a file to the DCC
171     """
172     encode.validate(metadata, 'file')
173
174     if dry_run:
175         LOGGER.info(json.dumps(metadata, indent=4, sort_keys=True))
176         return
177
178     upload = make_upload_filename(metadata)
179     if not os.path.exists(upload):
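        # The marker file is just the posted metadata as JSON; check_upload
        # later reads 'submitted_file_name', 'md5sum', and 'replicate' back
        # out of it.  A minimal sketch of its contents (illustrative values):
        #   {"submitted_file_name": "example.fastq.gz",
        #    "md5sum": "d41d8cd98f00b204e9800998ecf8427e",
        #    "replicate": "/replicates/example-uuid/"}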
        with open(upload, 'w') as outstream:
            json.dump(metadata, outstream, indent=4, sort_keys=True)
        LOGGER.debug(json.dumps(metadata, indent=4, sort_keys=True))

        response = encode.post_json('/file', metadata)
        LOGGER.info(json.dumps(response, indent=4, sort_keys=True))

        item = response['@graph'][0]
        creds = item['upload_credentials']
        run_aws_cp(metadata['submitted_file_name'], creds)
    else:
        LOGGER.info('%s already uploaded',
                    metadata['submitted_file_name'])


def make_upload_filename(metadata):
    """Name the local .upload marker file for a submitted file."""
    return metadata['submitted_file_name'] + '.upload'
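

# Rough usage sketch (illustrative only; `model`, `results_map`, and the
# hostnames below are supplied by the calling script, not this module):
#
#   submission = AWSSubmission('submission-name', model,
#                              encode_host='test.encodedcc.org',
#                              lims_host='lims.example.org')
#   submission.upload(results_map, dry_run=True)   # log the metadata only
#   submission.upload(results_map, dry_run=False)  # post metadata, copy to S3
#   submission.check_upload(results_map)           # verify afterwards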