Source code for _delb.utils

# Copyright (C) 2018-'22  Frank Sachsenheim
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations

import re
import sys
from collections import defaultdict
from copy import copy
from functools import partial
from string import ascii_lowercase
from typing import (
    TYPE_CHECKING,
    cast,
    Any,
    Callable,
    Iterable,
    Iterator,
    Optional,
    Sequence,
    Union,
)
from warnings import warn

from lxml import etree

from _delb.typing import Filter

if TYPE_CHECKING:
    from _delb.nodes import NodeBase, TagNode

_crunch_whitespace = partial(re.compile(r"\s+").sub, " ")


class _Nodes_Sorter:
    def __init__(self):
        self.__node = None
        self.__items = defaultdict(_Nodes_Sorter)

    def add(self, path: Sequence[int], node: TagNode):
        assert _is_node_of_type(node, "TagNode")
        if path:
            self.__items[path[0]].add(path[1:], node)
        else:
            self.__node = node

    def emit(self) -> Iterator[NodeBase]:
        if self.__node is not None:
            yield self.__node
        for index in sorted(self.__items):
            yield from self.__items[index].emit()


class _StringMixin:  # pragma: no cover
    # copied from CPython 3.10.0's stdlib collections.UserString and adjusted

    __slots__ = ()

    def __str__(self):
        return str(self._data)

    def __int__(self):
        return int(self._data)

    def __float__(self):
        return float(self._data)

    def __complex__(self):
        return complex(self._data)

    def __hash__(self):
        return hash(self._data)

    def __eq__(self, string):
        return self._data == string

    def __lt__(self, string):
        return self._data < string

    def __le__(self, string):
        return self._data <= string

    def __gt__(self, string):
        return self._data > string

    def __ge__(self, string):
        return self._data >= string

    def __contains__(self, char):
        return char in self._data

    def __len__(self):
        return len(self._data)

    def __getitem__(self, index):
        return self._data[index]

    def __add__(self, other):
        if isinstance(other, str):
            return self._data + other
        return self._data + str(other)

    def __radd__(self, other):
        if isinstance(other, str):
            return other + self._data
        return str(other) + self._data

    def __mul__(self, n):
        return self._data * n

    __rmul__ = __mul__

    def __mod__(self, args):
        return self._data % args

    def __rmod__(self, template):
        return str(template) % self

    def capitalize(self):
        return self._data.capitalize()

    def casefold(self):
        return self._data.casefold()

    def center(self, width, *args):
        return self._data.center(width, *args)

    def count(self, sub, start=0, end=sys.maxsize):
        return self._data.count(sub, start, end)

    def removeprefix(self, prefix):
        return self._data.removeprefix(prefix)

    def removesuffix(self, suffix):
        return self._data.removesuffix(suffix)

    def encode(self, encoding="utf-8", errors="strict"):
        encoding = "utf-8" if encoding is None else encoding
        errors = "strict" if errors is None else errors
        return self._data.encode(encoding, errors)

    def endswith(self, suffix, start=0, end=sys.maxsize):
        return self._data.endswith(suffix, start, end)

    def expandtabs(self, tabsize=8):
        return self._data.expandtabs(tabsize)

    def find(self, sub, start=0, end=sys.maxsize):
        return self._data.find(sub, start, end)

    def format(self, *args, **kwds):
        return self._data.format(*args, **kwds)

    def format_map(self, mapping):
        return self._data.format_map(mapping)

    def index(self, sub, start=0, end=sys.maxsize):
        return self._data.index(sub, start, end)

    def isalpha(self):
        return self._data.isalpha()

    def isalnum(self):
        return self._data.isalnum()

    def isascii(self):
        return self._data.isascii()

    def isdecimal(self):
        return self._data.isdecimal()

    def isdigit(self):
        return self._data.isdigit()

    def isidentifier(self):
        return self._data.isidentifier()

    def islower(self):
        return self._data.islower()

    def isnumeric(self):
        return self._data.isnumeric()

    def isprintable(self):
        return self._data.isprintable()

    def isspace(self):
        return self._data.isspace()

    def istitle(self):
        return self._data.istitle()

    def isupper(self):
        return self._data.isupper()

    def join(self, seq):
        return self._data.join(seq)

    def ljust(self, width, *args):
        return self._data.ljust(width, *args)

    def lower(self):
        return self._data.lower()

    def lstrip(self, chars=None):
        return self._data.lstrip(chars)

    def partition(self, sep):
        return self._data.partition(sep)

    def replace(self, old, new, maxsplit=-1):
        return self._data.replace(old, new, maxsplit)

    def rfind(self, sub, start=0, end=sys.maxsize):
        return self._data.rfind(sub, start, end)

    def rindex(self, sub, start=0, end=sys.maxsize):
        return self._data.rindex(sub, start, end)

    def rjust(self, width, *args):
        return self._data.rjust(width, *args)

    def rpartition(self, sep):
        return self._data.rpartition(sep)

    def rstrip(self, chars=None):
        return self._data.rstrip(chars)

    def split(self, sep=None, maxsplit=-1):
        return self._data.split(sep, maxsplit)

    def rsplit(self, sep=None, maxsplit=-1):
        return self._data.rsplit(sep, maxsplit)

    def splitlines(self, keepends=False):
        return self._data.splitlines(keepends)

    def startswith(self, prefix, start=0, end=sys.maxsize):
        return self._data.startswith(prefix, start, end)

    def strip(self, chars=None):
        return self._data.strip(chars)

    def swapcase(self):
        return self._data.swapcase()

    def title(self):
        return self._data.title()

    def translate(self, *args):
        return self._data.translate(*args)

    def upper(self):
        return self._data.upper()

    def zfill(self, width):
        return self._data.zfill(width)


def _better_call(f: Union[Callable, property]) -> Callable:
    def decorator(d: Callable) -> Callable:
        def wrapper(*args, **kwargs):
            """:meta category: deprecated"""
            warn(
                f"{d.__name__} is deprecated, use {f.__name__} instead.",
                category=DeprecationWarning,
            )
            return f(*args, **kwargs)

        d.__doc__ = ":meta: private"
        return wrapper

    return decorator


def _better_yield(f: Callable) -> Callable:
    def decorator(d: Callable) -> Callable:
        def wrapper(*args, **kwargs):
            """:meta category: deprecated"""
            warn(
                f"{d.__name__} is deprecated, use {f.__name__} instead.",
                category=DeprecationWarning,
            )
            yield from f(*args, **kwargs)

        return wrapper

    return decorator


def _copy_root_siblings(source: etree._Element, target: etree._Element):
    stack = []
    current_element = source.getprevious()
    while current_element is not None:
        stack.append(current_element)
        current_element = current_element.getprevious()
    while stack:
        target.addprevious(copy(stack.pop()))

    stack = []
    current_element = source.getnext()
    while current_element is not None:
        stack.append(current_element)
        current_element = current_element.getnext()
    while stack:
        target.addnext(copy(stack.pop()))


[docs]def first(iterable: Iterable) -> Optional[Any]: """ Returns the first item of the given :term:`iterable` or ``None`` if it's empty. Note that the first item is consumed when the iterable is an :term:`iterator`. """ if isinstance(iterable, Iterator): try: return next(iterable) except StopIteration: return None elif isinstance(iterable, Sequence): return iterable[0] if len(iterable) else None else: raise TypeError
[docs]def get_traverser(from_left=True, depth_first=True, from_top=True): """ Returns a function that can be used to traverse a (sub)tree with the given node as root. While traversing the given root node is yielded at some point. The returned functions have this signature: .. code-block:: python def traverser(root: NodeBase, *filters: Filter) -> Iterator[NodeBase]: ... :param from_left: The traverser yields sibling nodes from left to right if ``True``, or starting from the right if ``False``. :param depth_first: The child nodes resp. the parent node are yielded before the siblings of a node by a traverser if ``True``. Siblings are favored if ``False``. :param from_top: The traverser starts yielding nodes with the lowest depth if ``True``. When ``False``, again, the opposite is in effect. """ result = TRAVERSERS.get((from_left, depth_first, from_top)) if result is None: raise NotImplementedError return result
def _is_node_of_type(node: NodeBase, type_name: str) -> bool: if type_name not in { "CommentNode", "ProcessingInstructionNode", "TagNode", "TextNode", }: raise ValueError return ( node.__class__.__module__ == f"{__package__}.nodes" and node.__class__.__qualname__ == type_name )
[docs]def last(iterable: Iterable) -> Optional[Any]: """ Returns the last item of the given :term:`iterable` or ``None`` if it's empty. Note that the whole :term:`iterator` is consumed when such is given. """ if isinstance(iterable, Iterator): result = None for result in iterable: pass return result elif isinstance(iterable, Sequence): return iterable[-1] if len(iterable) else None else: raise TypeError
# REMOVE eventually def _random_unused_prefix(namespaces: "etree._NSMap") -> str: for prefix in ascii_lowercase: if prefix not in namespaces: return prefix raise RuntimeError( "You really are using all latin letters as prefix in a document. " "Fair enough, please open a bug report." )
[docs]def register_namespace(prefix: str, namespace: str): """ Registers a namespace prefix that newly created :class:`TagNode` instances in that namespace will use in serializations. The registry is global, and any existing mapping for either the given prefix or the namespace URI will be removed. It has however no effect on the serialization of existing nodes, see :meth:`Document.cleanup_namespace` for that. :param prefix: The prefix to register. :param namespace: The targeted namespace. """ warn( "This function will be replaced with a different mechanism in a future version " "without a backward-compatible facilitation through this function.", category=PendingDeprecationWarning, ) etree.register_namespace(prefix, namespace)
def sort_nodes_in_document_order(nodes: Iterable[NodeBase]) -> Iterator[NodeBase]: sorter = _Nodes_Sorter() for node in nodes: if not _is_node_of_type(node, "TagNode"): raise NotImplementedError( "Support for sorting other node types than TagNodes isn't scheduled" "yet." ) node = cast("TagNode", node) if node.parent is None: path = [] else: path = [int(x[2:-1]) for x in node.location_path.split("/")[1:]] sorter.add(path, node) yield from sorter.emit() # tree traversers def traverse_df_ltr_btt(root: "NodeBase", *filters: Filter) -> Iterator["NodeBase"]: def yield_children(node): for child in tuple(node.iterate_children(*filters)): yield from yield_children(child) yield node yield from yield_children(root) def traverse_df_ltr_ttb(root: "NodeBase", *filters: Filter) -> Iterator["NodeBase"]: yield root yield from root.iterate_descendants(*filters) TRAVERSERS = { (True, True, True): traverse_df_ltr_ttb, (True, True, False): traverse_df_ltr_btt, } __all__ = ( first.__name__, get_traverser.__name__, last.__name__, register_namespace.__name__, sort_nodes_in_document_order.__name__, )