Port rdfinfer to rdflib (fixing a hidden bug)
[htsworkflow.git] / htsworkflow / util / rdfinfer.py
index 42957be4709a308ca8f563f6ab8b4cfbeedb043b..122b517f7a8d42aa6ccd75b7fd6675a30f037a80 100644 (file)
@@ -2,7 +2,8 @@ import logging
 import os
 import sys
 
-import RDF
+from rdflib import ConjunctiveGraph, BNode, Literal, URIRef
+from rdflib.plugins.sparql import prepareQuery
 
 from htsworkflow.util.rdfns import *
 from htsworkflow.util.rdfhelp import SCHEMAS_URL
@@ -16,8 +17,11 @@ class Infer(object):
     Provides a few default rules as methods starting with _rule_
     """
     def __init__(self, model):
+        if not isinstance(model, ConjunctiveGraph):
+            raise ValueError("Inferences require a ConjunctiveGraph")
+
         self.model = model
-        self._context = RDF.Node(RDF.Uri(INFER_URL))
+        self._context = URIRef(INFER_URL)
 
 
     def think(self, max_iterations=None):
@@ -77,11 +81,10 @@ class Infer(object):
           ?alias a ?class .
           ?obj a ?alias .
         }"""
-        query = RDF.SPARQLQuery(body)
-        for r in query.execute(self.model):
-            s = RDF.Statement(r['obj'], rdfNS['type'], r['class'])
+        for r in self.model.query(body):
+            s = (r['obj'], RDF['type'], r['class'], self._context)
             if s not in self.model:
-                self.model.append(s, self._context)
+                self.model.add(s)
 
     def _rule_subclass(self):
         """A subclass is a parent class
@@ -96,11 +99,10 @@ class Infer(object):
           ?subclass rdfs:subClassOf ?parent .
           ?obj a ?subclass .
         }"""
-        query = RDF.SPARQLQuery(body)
-        for r in query.execute(self.model):
-            s = RDF.Statement(r['obj'], rdfNS['type'], r['parent'])
+        for r in self.model.query(body):
+            s = (r['obj'], RDF['type'], r['parent'], self._context)
             if s not in self.model:
-                self.model.append(s, self._context)
+                self.model.add(s)
 
     def _rule_inverse_of(self):
         """Add statements computed with inverseOf
@@ -121,14 +123,10 @@ class Infer(object):
             ?reverse rdfs:domain ?object_type ;
                   rdfs:range ?subject_type .
         }"""
-        query = RDF.SPARQLQuery(body)
-
-        statements = []
-        for r in query.execute(self.model):
-            s = RDF.Statement(r['o'], r['reverse'], r['s'])
+        for r in self.model.query(body):
+            s = (r['o'], r['reverse'], r['s'], self._context)
             if s not in self.model:
-                self.model.append(s, self._context)
-
+                self.model.add(s)
 
     def _validate_types(self):
         body = """
@@ -145,10 +143,9 @@ class Infer(object):
           FILTER(?predicate != xhtmlv:stylesheet)
         }
         """
-        query = RDF.SPARQLQuery(body)
         errmsg = "Missing type for: {0}"
-        for r in query.execute(self.model):
-            yield errmsg.format(str(r['subject']))
+        for r in self.model.query(body):
+            yield errmsg.format(str(r[0]))
 
     def _validate_undefined_properties(self):
         """Find properties that aren't defined.
@@ -164,25 +161,24 @@ class Infer(object):
             OPTIONAL { ?predicate a ?predicate_class }
             FILTER(!bound(?predicate_class))
         }"""
-        query = RDF.SPARQLQuery(body)
         msg = "Undefined property in {0} {1} {2}"
-        for r in query.execute(self.model):
-            yield msg.format(str(r['subject']),
-                             str(r['predicate']),
-                             str(r['object']))
+        for r in self.model.query(body):
+            yield msg.format(r['subject'],
+                             r['predicate'],
+                             r['object'])
 
     def _validate_property_types(self):
         """Find resources that don't have a type
         """
-        property_template = """
+        property_query = prepareQuery("""
         prefix rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#>
         prefix rdfs: <http://www.w3.org/2000/01/rdf-schema#>
 
-        select ?type
-        where {{
-            <{predicate}> a rdf:Property ;
-                        {space} ?type .
-        }}"""
+        select ?type ?predicate
+        where {
+            ?predicate a rdf:Property ;
+                        ?space ?type .
+        }""")
 
         def check_node_space(node, predicate, space, errmsg):
             """Check that a node conforms to it's allowable space of types.
@@ -193,53 +189,58 @@ class Infer(object):
             resource_error = "Expected resource for {0} in range {1}"
             type_error = "Type of {0} was {1} not {2}"
             # check domain
-            query = RDF.SPARQLQuery(property_template.format(
-                predicate=predicate.uri,
-                space=space))
             seen = set()
-            for r in query.execute(self.model):
+            errors = []
+            for i, r in enumerate(self.model.query(property_query,
+                                      initBindings={
+                                          'predicate': predicate,
+                                          'space': space})):
                 # Make sure we have a resource if we're expecting one
-                if r['type'] == rdfsNS['Resource']:
-                    if node.is_literal():
-                        return resource_error.format(str(node), space)
-                    continue
-                seen.add(str(r['type'].uri))
-                if node.is_literal():
-                    # literal is a generic type.
-                    nodetype = node.literal_value['datatype']
-                    if nodetype is None:
-                        # lets default to string
-                        nodetype = xsdNS['string'].uri
-                    if r['type'] == rdfsNS['Literal']:
-                        pass
-                    elif nodetype != r['type'].uri:
-                        return type_error.format(
-                            str(node), nodetype, r['type'])
-                # check that node is the expetected class type
-                check = RDF.Statement(node, rdfNS['type'], r['type'])
-                if self.model.contains_statement(check):
-                    return
-
-            # need the seen check, because we're surpressing checking
-            # rdfs:Resource types
-            if len(seen) > 0:
-                return errmsg + ",".join(seen)
-
+                expected_type = r['type']
+
+                if isinstance(node, Literal):
+                    if expected_type == RDFS['Literal']:
+                        return []
+                    elif node.datatype == expected_type:
+                        return []
+                    else:
+                        # not currently handling type hierarchy.
+                        # a integer could pass a range of decimal for instance.
+                        errors.append(
+                            "Type error: {} was type {}, expected {}".format(
+                                str(node),
+                                str(node.datatype),
+                                str(expected_type)))
+                elif expected_type == RDFS['Resource']:
+                    if isinstance(node, Literal):
+                        errors.append(resource_error.format(str(node), space))
+                    else:
+                        return []
+                else:
+                    check = (node, RDF['type'], expected_type)
+                    if check not in self.model:
+                        errors.append(errmsg + str(node) + ' was not a ' + str(expected_type))
+                    else:
+                        return []
+
+            return errors
+        ### End nested function
 
         wrong_domain_type = "Domain of {0} was not in:"
         wrong_range_type = "Range of {0} was not in:"
 
         count = 0
-        schema = RDF.Node(RDF.Uri(SCHEMAS_URL))
-        for s, context in self.model.as_stream_context():
+        schema = ConjunctiveGraph(identifier=SCHEMAS_URL)
+        for subject, predicate, obj, context in self.model.quads():
+            stmt = (subject, predicate, obj)
+
             if context == schema:
                 continue
             # check domain
-            msg = check_node_space(s.subject, s.predicate, 'rdfs:domain',
-                                   wrong_domain_type.format(str(s)))
-            if msg is not None: yield msg
+            for error in check_node_space(subject, predicate, RDFS.domain,
+                                          wrong_domain_type.format(str(stmt))):
+                yield error
             # check range
-            msg = check_node_space(s.object, s.predicate, 'rdfs:range',
-                                   wrong_range_type.format(str(s)))
-            if msg is not None: yield msg
-        return
+            for error in check_node_space(obj, predicate, RDFS.range,
+                                          wrong_range_type.format(str(stmt))):
+                yield error