Test htsworkflow under several different django & python versions
[htsworkflow.git] / encode_submission / importencoded.py
1 import argparse
2 import logging
3 import collections
4 import json
5 import pprint
6
# Module-level logger; NOTE(review): uses a hard-coded name rather than
# the conventional __name__ — confirm nothing configures this exact name.
logger = logging.getLogger('ImportEncoded')
8
9 from sqlalchemy.ext.declarative import declarative_base
10 from sqlalchemy import Column, Integer, String, create_engine
11 from sqlalchemy.dialects.postgresql import UUID, JSONB
12 from sqlalchemy.orm import sessionmaker
13
# Declarative base class shared by all ORM models defined in this module.
Base = declarative_base()
15
class Item(Base):
    """One generic ENCODE object stored as a single database row.

    The object's identifying fields are lifted into columns; everything
    else from the JSON dump lives in the JSONB ``payload`` column
    (see ``create_item`` for how rows are built).
    """
    __tablename__ = 'item'

    # 'uuid' field from the dump, used as the primary key
    uuid = Column(UUID, primary_key=True)
    # the object's '@id' (its canonical URI)
    uri = Column(String)
    # first element of the object's '@type' list
    object_type = Column(String)
    # all remaining attributes of the object, stored as JSONB
    payload = Column(JSONB)
23
24
def create_item(row):
    """Convert one JSON-LD object from the dump into an ``Item`` record.

    The '@id', 'uuid', and '@type' fields become dedicated columns;
    whatever remains of the object is kept in the JSONB payload.
    The caller's ``row`` dictionary is not modified.
    """
    payload = dict(row)
    uuid = payload.pop('uuid')
    uri = payload.pop('@id')
    object_type = payload.pop('@type')[0]

    return Item(uri=uri, uuid=uuid, object_type=object_type, payload=payload)
36
37
def create_session(engine):
    """Create and return a new SQLAlchemy session bound to *engine*.

    :param engine: an SQLAlchemy Engine to bind the session to.
    :returns: a Session instance ready for use.

    Bug fix: ``sessionmaker`` returns a session *factory*, not a
    session; the original returned the factory itself without calling
    it.  This now instantiates the factory, matching the pattern used
    in ``load_dump`` (``sessionfactory()``).
    """
    session_factory = sessionmaker(bind=engine)
    return session_factory()
41     
def load_data(session, graph):
    """Insert every unique object from *graph* into the database.

    Rows are deduplicated by uuid: the first occurrence is inserted and
    any later occurrences are collected and returned for inspection.

    :param session: SQLAlchemy session to add rows to (committed here).
    :param graph: sequence of JSON-LD object dictionaries, each with a
        'uuid' key.
    :returns: dict mapping a duplicated uuid to the list of extra rows
        that shared it.
    """
    seen_pkeys = set()
    duplicates = {}

    for i, row in enumerate(graph):
        obj_id = row['uuid']
        if obj_id not in seen_pkeys:
            session.add(create_item(row))
            seen_pkeys.add(obj_id)
        else:
            duplicates.setdefault(obj_id, []).append(row)

        # Commit in batches to bound the session's pending-object memory.
        if (i + 1) % 10000 == 0:
            session.commit()
            print("{} of {}".format(i+1, len(graph)))

    # Bug fix: commit the final partial batch.  Previously any trailing
    # rows (len(graph) % 10000 of them) were added but never committed.
    session.commit()

    return duplicates
59
def load_dump(filename, db_url='postgresql://felcat.caltech.edu/encoded'):
    """Load an ENCODE JSON dump file into the ``item`` table.

    :param filename: path to a JSON file containing an '@graph' list of
        JSON-LD objects.
    :param db_url: SQLAlchemy database URL.  Defaults to the original
        hard-coded server so existing callers are unaffected.

    Duplicate objects (the same uuid appearing more than once) are
    appended to 'bad.txt' in the current directory for later review.
    """
    logger.info("Creating schema")
    engine = create_engine(db_url)
    Base.metadata.create_all(engine)
    sessionfactory = sessionmaker(bind=engine)
    session = sessionfactory()

    logger.info("Parsing %s", filename)
    with open(filename, 'r') as instream:
        data = json.load(instream)

    graph = data['@graph']
    # Use the module logger, not the root logger (was logging.info).
    logger.info("Loading")
    collisions = load_data(session, graph)
    # Ensure the final partial batch is persisted even though load_data
    # only commits every 10000 rows; an extra commit is a harmless no-op.
    session.commit()

    with open('bad.txt', 'a') as outstream:
        outstream.write(pprint.pformat(collisions))
77
def main(cmdline=None):
    """Command-line entry point: load one or more JSON dump files.

    :param cmdline: optional argument list (useful for testing);
        when None, argparse falls back to sys.argv.
    """
    parser = argparse.ArgumentParser()
    # nargs='+' generalizes the original nargs=1: a single filename
    # still works exactly as before, and the existing loop below
    # already handles several files per run.
    parser.add_argument('filename', nargs='+',
                        help='json dump file(s) to load')

    args = parser.parse_args(cmdline)

    logging.basicConfig(level=logging.DEBUG)
    for filename in args.filename:
        load_dump(filename)
87
# Standard script guard: only run the loader when executed directly,
# not when imported as a module.
if __name__ == '__main__':
    main()
90