Source code for _delb.xpath.ast

# Copyright (C) 2018-'22  Frank Sachsenheim
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

from __future__ import annotations

import inspect
import operator
import sys
from functools import wraps
from textwrap import indent
from typing import (
    TYPE_CHECKING,
    cast,
    Any,
    Callable,
    Iterable,
    Iterator,
    List,
    NamedTuple,
    Optional,
    Sequence,
    Tuple,
)

from _delb.exceptions import InvalidCodePath, XPathEvaluationError, XPathParsingError
from _delb.names import Namespaces
from _delb.plugins import plugin_manager as _plugin_manager
from _delb.utils import _is_node_of_type, last

# REMOVE when support for Python 3.7 is dropped
if sys.version_info < (3, 8):
    cached_property = property
else:
    from functools import cached_property


if TYPE_CHECKING:
    from delb import NodeBase, ProcessingInstructionNode, TagNode


# Module-level alias of the plugin manager's ``xpath_functions`` attribute.
# ``Function.__init__`` looks up XPath functions by name in it via ``.get()``.
xpath_functions = _plugin_manager.xpath_functions


# helper


def ensure_prefix(func):
    """
    Decorates an ``evaluate`` method so that an unknown namespace prefix is rejected
    before the method itself runs.

    The decorated method's object must provide a ``prefix`` attribute. The method must
    be called with the namespaces passed by keyword, either directly as ``namespaces``
    or as the ``namespaces`` attribute of an object passed as ``context``.

    :raises XPathEvaluationError: When the object's prefix is not ``None`` and is not
                                  contained in the namespaces.
    """

    @wraps(func)
    def wrapper(self, node, **kwargs):
        prefix = self.prefix

        # predicates receive an evaluation context, node tests the bare namespaces
        if "context" in kwargs:
            namespaces = kwargs["context"].namespaces
        else:
            namespaces = kwargs["namespaces"]

        if prefix is not None and prefix not in namespaces:
            raise XPathEvaluationError(
                f"The namespace prefix `{prefix}` is unknown in the evaluation "
                "context."
            )
        return func(self, node, **kwargs)

    return wrapper


def nested_repr(obj: Any) -> str:  # pragma: no cover
    """
    Returns a multi-line representation of an object for debugging purposes.

    Each attribute that is named in the object's ``__slots__`` is rendered on its own
    line. The items of iterable values are listed one per line within brackets,
    all other values are rendered with their :func:`repr`.
    """
    parts = [f"{obj.__class__.__name__}(\n"]
    for name in obj.__slots__:
        value = getattr(obj, name)
        # strings are iterable too, but must not be split up into their characters
        if isinstance(value, Iterable) and not isinstance(value, (str, bytes)):
            items = "\n".join(indent(repr(x), "    ") for x in value)
            parts.append(f"  {name}=[\n{items}\n]\n")
        else:
            parts.append(f"  {name}={value!r}\n")
    parts.append(")")
    return "".join(parts)


# structs


class EvaluationContext(NamedTuple):
    """
    Instances of this type are passed to XPath functions in order to pass contextual
    information.
    """

    node: NodeBase
    """
    The node that is evaluated.
    """
    position: int
    """
    The node's position within all nodes that matched a location step's node test in
    order of the step's axis' direction. The first position is 1.
    """
    size: int
    """
    The number of all nodes that matched a location step's node test.
    """
    namespaces: Namespaces
    """
    A mapping of prefixes to namespaces that is used in the whole evaluation.
    """
# base classes for nodes


class Node:
    """
    Common base of all nodes of the syntax tree.

    Equality and representation are derived from the attributes that a concrete
    subclass names in its ``__slots__``.
    """

    def __eq__(self, other):
        return type(self) is type(other) and all(
            getattr(self, x) == getattr(other, x) for x in self.__slots__
        )

    def __repr__(self):
        return (
            f"{self.__class__.__qualname__}("
            f"{', '.join(f'{x}={getattr(self, x)!r}' for x in self.__slots__)})"
        )


class EvaluationNode(Node):
    """
    Base of the nodes that are evaluated with an :class:`EvaluationContext`, as it is
    done with the predicates of a :class:`LocationStep`.
    """

    def evaluate(self, node: NodeBase, context: EvaluationContext) -> bool:
        """Evaluates the expression for the given node. To be overridden."""
        raise NotImplementedError

    @property
    def _derived_attributes(self):
        """
        Attribute constraints that the expression implies. Only subclasses that
        override this property can provide them, here it is an invalid code path.
        """
        raise InvalidCodePath

    def _is_unambiguously_locatable(self) -> bool:
        """Unless a subclass states otherwise an expression is not locatable."""
        return False


class NodeTestNode(Node):
    """
    Base of the node tests of location steps. They are evaluated with a node and the
    namespaces of the evaluation.
    """

    def evaluate(self, node: NodeBase, namespaces: Namespaces) -> bool:
        """Tests the given node. To be overridden."""
        raise NotImplementedError


# aggregators


class Axis(Node):
    """
    An axis of a location step. The ``generator`` attribute holds the bound method
    that implements the axis that was specified by name.
    """

    __slots__ = ("generator",)

    # The methods that implement an axis. Other attributes of this class, such as
    # ``evaluate``, must not be selectable as axis.
    _AXES = frozenset(
        (
            "ancestor",
            "ancestor_or_self",
            "child",
            "descendant",
            "descendant_or_self",
            "following",
            "following_sibling",
            "parent",
            "preceding",
            "preceding_sibling",
            "self",
        )
    )

    def __init__(self, name: str):
        """
        :param name: The axis' name, hyphens and underscores are interchangeable.
        :raises XPathParsingError: When the name doesn't specify an implemented axis.
        """
        method_name = name.replace("-", "_")
        if method_name not in self._AXES:
            raise XPathParsingError(message="Invalid axis specifier.")
        self.generator = getattr(self, method_name)

    def __eq__(self, other):
        # bound methods of different instances are compared by their name
        return (
            isinstance(other, Axis)
            and self.generator.__name__ == other.generator.__name__
        )

    def __repr__(self):
        return f"{self.__class__.__name__}({self.generator.__name__})"

    def ancestor(self, node: NodeBase) -> Iterator[NodeBase]:
        yield from node.iterate_ancestors()

    def ancestor_or_self(self, node: NodeBase) -> Iterator[NodeBase]:
        yield node
        yield from node.iterate_ancestors()

    def evaluate(self, node: NodeBase, namespaces: Namespaces) -> Iterator[NodeBase]:
        """
        Yields the nodes that the axis contains in relation to the given node. The
        ``namespaces`` argument is accepted, but not used.
        """
        yield from self.generator(node)

    def child(self, node: NodeBase) -> Iterator[NodeBase]:
        yield from node.iterate_children()

    def descendant(self, node: NodeBase) -> Iterator[NodeBase]:
        yield from node.iterate_descendants()

    def descendant_or_self(self, node: NodeBase) -> Iterator[NodeBase]:
        yield node
        yield from node.iterate_descendants()

    def following(self, node: NodeBase) -> Iterator[NodeBase]:
        yield from node.iterate_following()

    def following_sibling(self, node: NodeBase) -> Iterator[NodeBase]:
        yield from node.iterate_following_siblings()

    def parent(self, node: NodeBase) -> Iterator[NodeBase]:
        parent = node.parent
        if parent:
            yield parent

    def preceding(self, node: NodeBase) -> Iterator[NodeBase]:
        yield from node.iterate_preceding()

    def preceding_sibling(self, node: NodeBase) -> Iterator[NodeBase]:
        yield from node.iterate_preceding_siblings()

    def self(self, node: NodeBase) -> Iterator[NodeBase]:
        yield node


class LocationPath(Node):
    """
    A sequence of location steps.

    A path with more than one step holds a ``parent_path`` that consists of all but
    the last step. Its results are the node set that the last step is evaluated with.
    """

    __slots__ = ("absolute", "location_steps", "parent_path")

    def __init__(self, location_steps: Iterable[LocationStep], absolute: bool = False):
        location_steps = tuple(location_steps)
        self.parent_path = (
            LocationPath(location_steps=location_steps[:-1], absolute=absolute)
            if len(location_steps) > 1
            else None
        )
        self.location_steps = location_steps
        self.absolute = absolute

    def __repr__(self):
        return nested_repr(self)

    def evaluate(self, node: NodeBase, namespaces: Namespaces) -> Iterator[NodeBase]:
        """
        Yields the nodes that the path locates, starting from the given node or,
        if the path is absolute, from the last of that node's ancestors.
        """
        if self.parent_path:
            parent_paths_result_generator = self.parent_path.evaluate(
                node=node, namespaces=namespaces
            )
            yield from self.location_steps[-1].evaluate(
                node_set=parent_paths_result_generator, namespaces=namespaces
            )
        elif self.absolute:
            first_node = node
            assert first_node is not None
            # NOTE(review): presumably `last` returns a falsy value when there are
            # no ancestors, in that case the node itself is used as root
            root = last(first_node.iterate_ancestors()) or first_node
            yield from self.location_steps[0].evaluate(
                node_set=(root,), namespaces=namespaces
            )
        else:
            yield from self.location_steps[0].evaluate(
                node_set=(node,), namespaces=namespaces
            )

    def _is_unambiguously_locatable(self) -> bool:
        """Whether all steps of the path are unambiguously locatable."""
        return all(s._is_unambiguously_locatable() for s in self.location_steps)


class LocationStep(Node):
    """
    A step of a location path that is made up of an axis, a node test and any number
    of predicates.
    """

    __slots__ = ("axis", "node_test", "predicates")

    def __init__(
        self,
        axis: Axis,
        node_test: NodeTestNode,
        predicates: Sequence[EvaluationNode] = (),
    ):
        self.axis = axis
        self.node_test = node_test
        self.predicates = tuple(predicates)

    @cached_property
    def _anders_predicates(self) -> "BooleanOperator":
        """
        All predicates chained to one right-nested expression of ``and`` operations.
        This must only be used when there's more than one predicate.
        """
        predicates = list(self.predicates)
        right = predicates.pop()
        while predicates:
            left = predicates.pop()
            right = BooleanOperator(operator=operator.and_, left=left, right=right)
        assert isinstance(right, BooleanOperator)
        return right

    @cached_property
    def _derived_attributes(self) -> List[Tuple[Optional[str], str, str]]:
        """
        The attributes that the predicates imply as tuples of prefix, local name and
        value.
        """
        predicates_count = len(self.predicates)

        if predicates_count == 0:
            return []
        elif predicates_count == 1:
            return self.predicates[0]._derived_attributes
        else:
            return self._anders_predicates._derived_attributes

    def evaluate(
        self, node_set: Iterable[NodeBase], namespaces: Namespaces
    ) -> Iterator[NodeBase]:
        """
        Yields the nodes that the step locates for each node of the given node set.
        An object that was already yielded is not yielded again.
        """
        yielded_nodes = set()
        for node in node_set:
            for result_node in self._evaluate(node=node, namespaces=namespaces):
                _id = id(result_node)
                if _id not in yielded_nodes:
                    yielded_nodes.add(_id)
                    yield result_node

    def _evaluate(self, node: NodeBase, namespaces: Namespaces) -> Sequence[NodeBase]:
        """
        Returns the nodes on the step's axis of the given node that pass the node
        test and all predicates.

        The predicates are applied one after another. Position and size of the
        evaluation context refer to the nodes that passed the preceding filters.
        """
        node_test = self.node_test
        candidates = [
            n
            for n in self.axis.evaluate(node=node, namespaces=namespaces)
            if node_test.evaluate(node=n, namespaces=namespaces)
        ]

        for predicate in self.predicates:
            size = len(candidates)
            candidates = [
                candidate
                for position, candidate in enumerate(candidates, start=1)
                if predicate.evaluate(
                    node=candidate,
                    context=EvaluationContext(
                        node=candidate,
                        position=position,
                        size=size,
                        namespaces=namespaces,
                    ),
                )
            ]

        return candidates

    def _is_unambiguously_locatable(self) -> bool:
        """
        Whether the step uses the child axis with a name test and, if there are any,
        only predicates that are unambiguously locatable.
        """
        if not (
            self.axis.generator.__name__ == "child"
            and isinstance(self.node_test, NameMatchTest)
        ):
            return False

        predicates_count = len(self.predicates)

        if predicates_count == 0:
            return True
        elif predicates_count == 1:
            return self.predicates[0]._is_unambiguously_locatable()
        else:
            return self._anders_predicates._is_unambiguously_locatable()


class XPathExpression(Node):
    """An expression that consists of one or more location paths."""

    __slots__ = ("location_paths",)

    def __init__(self, location_paths: List[LocationPath]):
        self.location_paths = location_paths

    def __repr__(self):
        return nested_repr(self)

    def evaluate(self, node: NodeBase, namespaces: Namespaces) -> Iterator[NodeBase]:
        """
        Yields the results of all location paths in their order. An object that was
        already yielded is not yielded again.
        """
        yielded_nodes: set[int] = set()
        for path in self.location_paths:
            for result in path.evaluate(node=node, namespaces=namespaces):
                _id = id(result)
                if _id not in yielded_nodes:
                    yielded_nodes.add(_id)
                    yield result

    # unlike the other classes' methods with that name, this is accessed as attribute
    @cached_property
    def _is_unambiguously_locatable(self) -> bool:
        """
        Whether the expression consists of exactly one location path that is
        unambiguously locatable.
        """
        return (
            len(self.location_paths) == 1
            and self.location_paths[0]._is_unambiguously_locatable()
        )


# node tests


class NameMatchTest(NodeTestNode):
    """Tests whether a node is a tag node with a specific name."""

    __slots__ = ("local_name", "prefix")

    def __init__(self, prefix: Optional[str], local_name: str):
        self.prefix = prefix
        self.local_name = local_name

    @ensure_prefix
    def evaluate(self, node: NodeBase, namespaces) -> bool:
        if not _is_node_of_type(node, "TagNode"):
            return False
        node = cast("TagNode", node)
        # the namespace is only compared when the test has a prefix or when the
        # namespaces contain a namespace for the key `None`
        if (self.prefix or None in namespaces) and node.namespace != namespaces.get(
            self.prefix
        ):
            return False
        return node.local_name == self.local_name


class NameStartTest(NodeTestNode):
    """Tests whether a node is a tag node whose local name starts with a string."""

    __slots__ = ("prefix", "start")

    def __init__(self, prefix: Optional[str], start: str):
        self.prefix = prefix
        self.start = start

    @ensure_prefix
    def evaluate(self, node: NodeBase, namespaces) -> bool:
        if not _is_node_of_type(node, "TagNode"):
            return False
        node = cast("TagNode", node)
        # the namespace is only compared when the test has a prefix or when the
        # namespaces contain a namespace for the key `None`
        if (self.prefix or None in namespaces) and node.namespace != namespaces.get(
            self.prefix
        ):
            return False
        return node.local_name.startswith(self.start)


class NodeTypeTest(NodeTestNode):
    """Tests whether a node is of the type with the given name."""

    __slots__ = ("type_name",)

    def __init__(self, type_name: str):
        self.type_name = type_name

    def evaluate(self, node: NodeBase, namespaces) -> bool:
        return _is_node_of_type(node, self.type_name)


class ProcessingInstructionTest(NodeTypeTest):
    """Tests whether a node is a processing instruction with a specific target."""

    # `type_name` is repeated here because `Node.__eq__` and `Node.__repr__` read the
    # attribute names from `self.__slots__`, which resolves to this tuple
    __slots__ = ("target", "type_name")

    def __init__(self, target: str):
        super().__init__("ProcessingInstructionNode")
        self.target = target

    def evaluate(self, node: NodeBase, namespaces) -> bool:
        if not super().evaluate(node=node, namespaces=namespaces):
            return False
        assert _is_node_of_type(node, "ProcessingInstructionNode")
        return cast("ProcessingInstructionNode", node).target == self.target


# evaluation


class AnyValue(EvaluationNode):
    """A constant value that is returned regardless of the evaluated node."""

    __slots__ = ("value",)

    def __init__(self, value: Any):
        self.value = value

    def evaluate(self, node: NodeBase, context: EvaluationContext) -> Any:
        return self.value


class AttributeValue(EvaluationNode):
    """Evaluates to the value of an attribute of the evaluated node."""

    __slots__ = ("local_name", "prefix")

    def __init__(self, prefix: Optional[str], name: str):
        self.prefix = prefix
        self.local_name = name

    @ensure_prefix
    def evaluate(self, node: NodeBase, context: EvaluationContext) -> Optional[str]:
        """
        Returns ``None`` for nodes that aren't tag nodes, an empty string if the
        attribute lookup results in ``None`` and the found attribute's value
        otherwise.
        """
        if not _is_node_of_type(node, "TagNode"):
            return None
        node = cast("TagNode", node)
        # the attribute is addressed with a slice of its namespace and local name
        result = node.attributes[
            context.namespaces.get(self.prefix) : self.local_name  # type: ignore
        ]
        return "" if result is None else result.value


class BooleanOperator(EvaluationNode):
    """
    Applies a callable that takes two arguments to the results of a left and a right
    operand.
    """

    __slots__ = ("left", "operator", "right")

    def __init__(
        self,
        operator: Callable,
        left: EvaluationNode,
        right: EvaluationNode,
    ):
        self.operator = operator
        self.left = left
        self.right = right

    @property
    def _derived_attributes(self) -> List[Tuple[Optional[str], str, str]]:
        """
        The attributes that the expression implies as tuples of prefix, local name
        and value.

        These can only be derived from comparisons of an attribute with a value for
        equality and from conjunctions of such. Any other operator is an invalid code
        path.
        """
        if self.operator is operator.and_:
            return self.left._derived_attributes + self.right._derived_attributes
        elif self.operator is operator.eq:
            left, right = self.left, self.right
            if isinstance(left, AttributeValue):
                assert isinstance(right, AnyValue)
                return [
                    (left.prefix, left.local_name, right.value),
                ]
            else:
                assert isinstance(left, AnyValue) and isinstance(right, AttributeValue)
                return [
                    (right.prefix, right.local_name, left.value),
                ]

        raise InvalidCodePath

    def evaluate(self, node: NodeBase, context: EvaluationContext) -> Any:
        # both operands are always evaluated before the operator is applied
        return self.operator(
            self.left.evaluate(node=node, context=context),
            self.right.evaluate(node=node, context=context),
        )

    def _is_unambiguously_locatable(self) -> bool:
        """
        Whether the expression is a comparison of an attribute with a string for
        equality, in either order of operands, or a conjunction of such.
        """
        if self.operator is operator.and_:
            return (
                self.left._is_unambiguously_locatable()
                and self.right._is_unambiguously_locatable()
            )
        elif self.operator is operator.eq:
            return (
                isinstance(self.left, AttributeValue)
                and (
                    isinstance(self.right, AnyValue)
                    and isinstance(self.right.value, str)
                )
            ) or (
                (isinstance(self.left, AnyValue) and isinstance(self.left.value, str))
                and isinstance(self.right, AttributeValue)
            )
        else:
            return False


class Function(EvaluationNode):
    """
    A call of an XPath function that is looked up by its name in
    ``xpath_functions``.
    """

    __slots__ = ("arguments", "function")

    def __init__(self, name: str, arguments: Sequence[EvaluationNode]):
        """
        :raises XPathParsingError: When no function with the given name is available
                                   or when the number of arguments doesn't fit the
                                   function's signature.
        """
        function = xpath_functions.get(name)
        if function is None:
            raise XPathParsingError(message=f"Unknown function: `{name}`")

        parameters = inspect.signature(function).parameters
        # The evaluation context is passed as additional first argument, hence the
        # offset of one. The count isn't checked for functions with a single parameter
        # and for those whose last parameter collects positional arguments.
        if (
            len(parameters) > 1
            and tuple(parameters.values())[-1].kind != inspect.Parameter.VAR_POSITIONAL
            and len(parameters) != len(arguments) + 1
        ):
            raise XPathParsingError(
                message=f"Arguments to function `{name}` don't match its signature."
            )

        self.function = function
        self.arguments = tuple(arguments)

    def __eq__(self, other):
        return (
            isinstance(other, Function)
            and self.function is other.function
            and self.arguments == other.arguments
        )

    def evaluate(self, node: NodeBase, context: EvaluationContext) -> Any:
        """
        Calls the function with the context and the evaluated arguments and returns
        its result.
        """
        return self.function(
            context, *(x.evaluate(node=node, context=context) for x in self.arguments)
        )


class HasAttribute(EvaluationNode):
    """Tests whether the evaluated node is a tag node with a specific attribute."""

    __slots__ = ("local_name", "prefix")

    def __init__(self, prefix: Optional[str], local_name: str):
        self.prefix = prefix
        self.local_name = local_name

    @ensure_prefix
    def evaluate(self, node: NodeBase, context: EvaluationContext) -> bool:
        if not _is_node_of_type(node, "TagNode"):
            return False
        node = cast("TagNode", node)
        # the attribute is addressed with a slice of its namespace and local name
        return (
            node.attributes[
                context.namespaces.get(self.prefix) : self.local_name  # type: ignore
            ]
            is not None
        )


__all__ = (
    Axis.__name__,
    LocationPath.__name__,
    LocationStep.__name__,
    NameMatchTest.__name__,
    XPathExpression.__name__,
)