# Source code for _delb.utils

# Copyright (C) 2018-'22  Frank Sachsenheim
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

from __future__ import annotations

import re
import sys
from collections import defaultdict, deque
from collections.abc import Callable, Iterable, Iterator, Sequence
from copy import copy
from functools import partial
from string import ascii_lowercase
from typing import TYPE_CHECKING, cast, Any, Optional
from warnings import warn

if TYPE_CHECKING:
    from lxml import etree

    from _delb.nodes import NodeBase, TagNode
    from _delb.typing import Filter

_crunch_whitespace = partial(re.compile(r"\s+").sub, " ")


class _NodesSorter:
    def __init__(self):
        self.__node = None
        self.__items = defaultdict(_NodesSorter)

    def add(self, path: Sequence[int], node: TagNode):
        assert _is_node_of_type(node, "TagNode")
        if path:
            self.__items[path[0]].add(path[1:], node)
        else:
            self.__node = node

    def emit(self) -> Iterator[NodeBase]:
        if self.__node is not None:
            yield self.__node
        for index in sorted(self.__items):
            yield from self.__items[index].emit()


class _StringMixin:  # pragma: no cover
    # copied from CPython 3.10.0's stdlib collections.UserString and adjusted

    __slots__ = ()

    def __str__(self):
        return str(self._data)

    def __int__(self):
        return int(self._data)

    def __float__(self):
        return float(self._data)

    def __complex__(self):
        return complex(self._data)

    def __hash__(self):
        return hash(self._data)

    def __eq__(self, string):
        return self._data == string

    def __lt__(self, string):
        return self._data < string

    def __le__(self, string):
        return self._data <= string

    def __gt__(self, string):
        return self._data > string

    def __ge__(self, string):
        return self._data >= string

    def __contains__(self, char):
        return char in self._data

    def __len__(self):
        return len(self._data)

    def __getitem__(self, index):
        return self._data[index]

    def __add__(self, other):
        if isinstance(other, str):
            return self._data + other
        return self._data + str(other)

    def __radd__(self, other):
        if isinstance(other, str):
            return other + self._data
        return str(other) + self._data

    def __mul__(self, n):
        return self._data * n

    __rmul__ = __mul__

    def __mod__(self, args):
        return self._data % args

    def __rmod__(self, template):
        return str(template) % self

    def capitalize(self):
        return self._data.capitalize()

    def casefold(self):
        return self._data.casefold()

    def center(self, width, *args):
        return self._data.center(width, *args)

    def count(self, sub, start=0, end=sys.maxsize):
        return self._data.count(sub, start, end)

    def removeprefix(self, prefix):
        return self._data.removeprefix(prefix)

    def removesuffix(self, suffix):
        return self._data.removesuffix(suffix)

    def encode(self, encoding="utf-8", errors="strict"):
        encoding = "utf-8" if encoding is None else encoding
        errors = "strict" if errors is None else errors
        return self._data.encode(encoding, errors)

    def endswith(self, suffix, start=0, end=sys.maxsize):
        return self._data.endswith(suffix, start, end)

    def expandtabs(self, tabsize=8):
        return self._data.expandtabs(tabsize)

    def find(self, sub, start=0, end=sys.maxsize):
        return self._data.find(sub, start, end)

    def format(self, *args, **kwds):
        return self._data.format(*args, **kwds)

    def format_map(self, mapping):
        return self._data.format_map(mapping)

    def index(self, sub, start=0, end=sys.maxsize):
        return self._data.index(sub, start, end)

    def isalpha(self):
        return self._data.isalpha()

    def isalnum(self):
        return self._data.isalnum()

    def isascii(self):
        return self._data.isascii()

    def isdecimal(self):
        return self._data.isdecimal()

    def isdigit(self):
        return self._data.isdigit()

    def isidentifier(self):
        return self._data.isidentifier()

    def islower(self):
        return self._data.islower()

    def isnumeric(self):
        return self._data.isnumeric()

    def isprintable(self):
        return self._data.isprintable()

    def isspace(self):
        return self._data.isspace()

    def istitle(self):
        return self._data.istitle()

    def isupper(self):
        return self._data.isupper()

    def join(self, seq):
        return self._data.join(seq)

    def ljust(self, width, *args):
        return self._data.ljust(width, *args)

    def lower(self):
        return self._data.lower()

    def lstrip(self, chars=None):
        return self._data.lstrip(chars)

    def partition(self, sep):
        return self._data.partition(sep)

    def replace(self, old, new, maxsplit=-1):
        return self._data.replace(old, new, maxsplit)

    def rfind(self, sub, start=0, end=sys.maxsize):
        return self._data.rfind(sub, start, end)

    def rindex(self, sub, start=0, end=sys.maxsize):
        return self._data.rindex(sub, start, end)

    def rjust(self, width, *args):
        return self._data.rjust(width, *args)

    def rpartition(self, sep):
        return self._data.rpartition(sep)

    def rstrip(self, chars=None):
        return self._data.rstrip(chars)

    def split(self, sep=None, maxsplit=-1):
        return self._data.split(sep, maxsplit)

    def rsplit(self, sep=None, maxsplit=-1):
        return self._data.rsplit(sep, maxsplit)

    def splitlines(self, keepends=False):
        return self._data.splitlines(keepends)

    def startswith(self, prefix, start=0, end=sys.maxsize):
        return self._data.startswith(prefix, start, end)

    def strip(self, chars=None):
        return self._data.strip(chars)

    def swapcase(self):
        return self._data.swapcase()

    def title(self):
        return self._data.title()

    def translate(self, *args):
        return self._data.translate(*args)

    def upper(self):
        return self._data.upper()

    def zfill(self, width):
        return self._data.zfill(width)


def _better_call(f: Callable | property) -> Callable:
    def decorator(d: Callable) -> Callable:
        def wrapper(*args, **kwargs):
            """:meta category: deprecated"""
            warn(
                f"{d.__name__} is deprecated, use {f.__name__} instead.",
                category=DeprecationWarning,
            )
            return f(*args, **kwargs)

        d.__doc__ = ":meta: private"
        return wrapper

    return decorator


def _better_yield(f: Callable) -> Callable:
    def decorator(d: Callable) -> Callable:
        def wrapper(*args, **kwargs):
            """:meta category: deprecated"""
            warn(
                f"{d.__name__} is deprecated, use {f.__name__} instead.",
                category=DeprecationWarning,
            )
            yield from f(*args, **kwargs)

        return wrapper

    return decorator


def _copy_root_siblings(source: etree._Element, target: etree._Element):
    stack = []
    current_element = source.getprevious()
    while current_element is not None:
        stack.append(current_element)
        current_element = current_element.getprevious()
    while stack:
        target.addprevious(copy(stack.pop()))

    stack = []
    current_element = source.getnext()
    while current_element is not None:
        stack.append(current_element)
        current_element = current_element.getnext()
    while stack:
        target.addnext(copy(stack.pop()))


def first(iterable: Iterable) -> Optional[Any]:
    """
    Returns the first item of the given :term:`iterable` or :py:obj:`None` if
    it's empty. Note that the first item is consumed when the iterable is an
    :term:`iterator`.

    :param iterable: An iterator or a sequence.
    :return: The first item or :py:obj:`None` if there is none.
    :raises TypeError: If the given object is neither an iterator nor a
                       sequence.
    """
    if isinstance(iterable, Iterator):
        return next(iterable, None)
    if isinstance(iterable, Sequence):
        # the length is tested explicitly as a sequence type may define its
        # truth value independently from its length
        return iterable[0] if len(iterable) else None
    raise TypeError(
        "Expected an iterator or a sequence, "
        f"got an instance of {type(iterable).__name__}."
    )
def get_traverser(from_left=True, depth_first=True, from_top=True):
    """
    Returns a function that can be used to traverse a (sub)tree with the given
    node as root. While traversing the given root node is yielded at some
    point.

    The returned functions have this signature:

    .. code-block:: python

        def traverser(root: NodeBase, *filters: Filter) -> Iterator[NodeBase]:
            ...

    :param from_left: The traverser yields sibling nodes from left to right if
                      :py:obj:`True`, or starting from the right if
                      :py:obj:`False`.
    :param depth_first: The child nodes resp. the parent node are yielded
                        before the siblings of a node by a traverser if
                        :py:obj:`True`. Siblings are favored if
                        :py:obj:`False`.
    :param from_top: The traverser starts yielding nodes with the lowest depth
                     if :py:obj:`True`. When :py:obj:`False`, again, the
                     opposite is in effect.
    :raises NotImplementedError: If no traverser is registered for the
                                 requested combination of options.
    """
    options = (from_left, depth_first, from_top)
    traverser = TRAVERSERS.get(options)
    if traverser is not None:
        return traverser
    raise NotImplementedError
def _is_node_of_type(node: NodeBase, type_name: str) -> bool: if type_name not in { "CommentNode", "ProcessingInstructionNode", "TagNode", "TextNode", }: raise ValueError return ( node.__class__.__module__ == f"{__package__}.nodes" and node.__class__.__qualname__ == type_name )
def last(iterable: Iterable) -> Optional[Any]:
    """
    Returns the last item of the given :term:`iterable` or :py:obj:`None` if
    it's empty. Note that the whole :term:`iterator` is consumed when such is
    given.

    :param iterable: An iterator or a sequence.
    :return: The last item or :py:obj:`None` if there is none.
    :raises TypeError: If the given object is neither an iterator nor a
                       sequence.
    """
    if isinstance(iterable, Iterator):
        result = None
        # the loop only serves to bind the most recently yielded item
        for result in iterable:
            pass
        return result
    if isinstance(iterable, Sequence):
        # the length is tested explicitly as a sequence type may define its
        # truth value independently from its length
        return iterable[-1] if len(iterable) else None
    raise TypeError(
        "Expected an iterator or a sequence, "
        f"got an instance of {type(iterable).__name__}."
    )
# REMOVE eventually
def _random_unused_prefix(namespaces: etree._NSMap) -> str:
    """
    Returns the first lowercase ASCII letter that isn't contained in the given
    namespace mapping.

    :raises RuntimeError: If all 26 letters are contained in the mapping.
    """
    for prefix in ascii_lowercase:
        if prefix not in namespaces:
            return prefix
    raise RuntimeError(
        "You really are using all latin letters as prefix in a document. "
        "Fair enough, please open a bug report."
    )


def sort_nodes_in_document_order(nodes: Iterable[NodeBase]) -> Iterator[NodeBase]:
    """
    Yields the given tag nodes ordered by the index paths that are derived
    from their location paths.

    All given nodes are consumed before the first one is yielded.

    :param nodes: The tag nodes to sort.
    :raises NotImplementedError: If one of the given nodes isn't a tag node.
    """
    sorter = _NodesSorter()
    for node in nodes:
        if not _is_node_of_type(node, "TagNode"):
            raise NotImplementedError(
                "Support for sorting other node types than TagNodes isn't "
                "scheduled yet."
            )
        node = cast("TagNode", node)
        if node.parent is None:
            path = []
        else:
            # the part before the first slash is skipped; of each remaining
            # step the first two and the last character are stripped, what's
            # left must be an integer
            # NOTE(review): presumably the steps look like `*[n]`, confirm
            # against TagNode.location_path
            path = [int(x[2:-1]) for x in node.location_path.split("/")[1:]]
        sorter.add(path, node)
    yield from sorter.emit()


# tree traversers
#
# NOTE(review): the traversers treat filters differently. The breadth-first
# one yields the root only if it passes all filters and still descends into
# tag nodes that don't pass, while both depth-first ones always yield the
# root and hand the filters over to the nodes' iteration methods.


def traverse_bf_ltr_ttb(root: NodeBase, *filters: Filter) -> Iterator[NodeBase]:
    """
    Traverses breadth-first, from left to right, from top to bottom. Only
    nodes that pass all filters are yielded, which includes the root.
    """
    if all(f(root) for f in filters):
        yield root
    queue = deque(root.iterate_children())
    while queue:
        node = queue.popleft()
        # children are enqueued no matter whether the node itself is yielded
        if _is_node_of_type(node, "TagNode"):
            queue.extend(node.iterate_children())
        if all(f(node) for f in filters):
            yield node


def traverse_df_ltr_btt(root: NodeBase, *filters: Filter) -> Iterator[NodeBase]:
    """
    Traverses depth-first, from left to right, from bottom to top: a node is
    yielded after the children that its ``iterate_children`` method returns
    for the given filters. The root is yielded last.
    """

    def yield_children(node):
        # the children are collected in a tuple before they are traversed
        for child in tuple(node.iterate_children(*filters)):
            yield from yield_children(child)
        yield node

    yield from yield_children(root)


def traverse_df_ltr_ttb(root: NodeBase, *filters: Filter) -> Iterator[NodeBase]:
    """
    Traverses depth-first, from left to right, from top to bottom: the root
    is yielded first, followed by what its ``iterate_descendants`` method
    yields for the given filters.
    """
    yield root
    yield from root.iterate_descendants(*filters)


# maps the option triples (from_left, depth_first, from_top) that
# get_traverser accepts to the implemented traversers
TRAVERSERS = {
    (True, False, True): traverse_bf_ltr_ttb,
    (True, True, True): traverse_df_ltr_ttb,
    (True, True, False): traverse_df_ltr_btt,
}


__all__ = (
    first.__name__,
    get_traverser.__name__,
    last.__name__,
    sort_nodes_in_document_order.__name__,
)