# Copyright 2022-2025 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Utilities for working with ``PyTree``\s.

The :mod:`optree.pytree` namespace contains aliases of ``optree.tree_*`` utilities.

>>> import optree.pytree as pytree
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
>>> leaves, treespec = pytree.flatten(tree)
>>> leaves, treespec  # doctest: +IGNORE_WHITESPACE
(
    [1, 2, 3, 4, 5],
    PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': None, 'd': *})
)
>>> tree == pytree.unflatten(treespec, leaves)
True

.. versionadded:: 0.14.1
"""

from __future__ import annotations

import functools as _functools
import inspect as _inspect
import sys as _sys
from builtins import all as _all
from types import ModuleType as _ModuleType
from typing import TYPE_CHECKING as _TYPE_CHECKING

import optree.dataclasses as dataclasses
import optree.functools as functools
from optree.accessors import PyTreeEntry
from optree.ops import tree_accessors as accessors
from optree.ops import tree_all as all  # pylint: disable=redefined-builtin
from optree.ops import tree_any as any  # pylint: disable=redefined-builtin
from optree.ops import tree_broadcast_common as broadcast_common
from optree.ops import tree_broadcast_map as broadcast_map
from optree.ops import tree_broadcast_map_with_accessor as broadcast_map_with_accessor
from optree.ops import tree_broadcast_map_with_path as broadcast_map_with_path
from optree.ops import tree_broadcast_prefix as broadcast_prefix
from optree.ops import tree_flatten as flatten
from optree.ops import tree_flatten_one_level as flatten_one_level
from optree.ops import tree_flatten_with_accessor as flatten_with_accessor
from optree.ops import tree_flatten_with_path as flatten_with_path
from optree.ops import tree_is_leaf as is_leaf
from optree.ops import tree_iter as iter  # pylint: disable=redefined-builtin
from optree.ops import tree_leaves as leaves
from optree.ops import tree_map as map  # pylint: disable=redefined-builtin
from optree.ops import tree_map_ as map_
from optree.ops import tree_map_with_accessor as map_with_accessor
from optree.ops import tree_map_with_accessor_ as map_with_accessor_
from optree.ops import tree_map_with_path as map_with_path
from optree.ops import tree_map_with_path_ as map_with_path_
from optree.ops import tree_max as max  # pylint: disable=redefined-builtin
from optree.ops import tree_min as min  # pylint: disable=redefined-builtin
from optree.ops import tree_partition as partition
from optree.ops import tree_paths as paths
from optree.ops import tree_reduce as reduce
from optree.ops import tree_replace_nones as replace_nones
from optree.ops import tree_structure as structure
from optree.ops import tree_sum as sum  # pylint: disable=redefined-builtin
from optree.ops import tree_transpose as transpose
from optree.ops import tree_transpose_map as transpose_map
from optree.ops import tree_transpose_map_with_accessor as transpose_map_with_accessor
from optree.ops import tree_transpose_map_with_path as transpose_map_with_path
from optree.ops import tree_unflatten as unflatten
from optree.registry import dict_insertion_ordered
from optree.registry import register_pytree_node as register_node
from optree.registry import register_pytree_node_class as register_node_class
from optree.registry import unregister_pytree_node as unregister_node
from optree.typing import PyTreeKind, PyTreeSpec
from optree.version import __version__ as __version__  # pylint: disable=useless-import-alias


# Public API of this module. ``reexport()`` builds its re-exported module from this list (skipping
# ``'reexport'`` itself) and resolves each remaining name on this module via ``getattr``.
# NOTE: several aliases (``all``, ``any``, ``iter``, ``map``, ``max``, ``min``, ``sum``) shadow the
# builtins at module scope; code in this module must use e.g. ``_all`` to reach the builtin.
__all__ = [
    'reexport',
    # Types (from ``optree.typing`` / ``optree.accessors``).
    'PyTreeSpec',
    'PyTreeKind',
    'PyTreeEntry',
    # Aliases of ``optree.ops.tree_*`` functions (``tree_`` prefix dropped).
    'flatten',
    'flatten_with_path',
    'flatten_with_accessor',
    'unflatten',
    'iter',
    'leaves',
    'structure',
    'paths',
    'accessors',
    'is_leaf',
    'map',
    'map_',
    'map_with_path',
    'map_with_path_',
    'map_with_accessor',
    'map_with_accessor_',
    'replace_nones',
    'partition',
    'transpose',
    'transpose_map',
    'transpose_map_with_path',
    'transpose_map_with_accessor',
    'broadcast_prefix',
    'broadcast_common',
    'broadcast_map',
    'broadcast_map_with_path',
    'broadcast_map_with_accessor',
    'reduce',
    'sum',
    'max',
    'min',
    'all',
    'any',
    'flatten_one_level',
    # Aliases of ``optree.registry`` helpers.
    'register_node',
    'register_node_class',
    'unregister_node',
    'dict_insertion_ordered',
]


if _TYPE_CHECKING:
    # Names used only in annotations. With ``from __future__ import annotations`` the annotations
    # are never evaluated at runtime, so these imports are needed by static type checkers only.
    from collections.abc import Callable, Iterable
    from typing import Any, TypeVar  # pylint: disable=ungrouped-imports
    from typing_extensions import ParamSpec  # ``typing.ParamSpec`` requires Python 3.10+

    # Type variables for the signature-preserving wrappers built by ``ReexportedModule.__reexport__``.
    _P = ParamSpec('_P')
    _T = TypeVar('_T')


class ReexportedModule(_ModuleType):
    """A module object that mirrors the public API of another module.

    Attributes are resolved lazily from the original module on first access and then cached on this
    module. Plain Python functions are wrapped: when they accept a ``namespace`` argument, the
    wrapper defaults it to the namespace given at construction time; otherwise the wrapper simply
    forwards the call. Non-function attributes are returned unchanged.
    """

    __doc__: str

    def __init__(
        self,
        name: str,
        *,
        namespace: str,
        original: _ModuleType,
        doc: str | None = None,
        __all__: Iterable[str] | None = None,
        __dir__: Iterable[str] | None = None,
        extra_members: dict[str, Any] | None = None,
    ) -> None:
        """Create a re-exported module.

        Args:
            name: The ``__name__`` of the new module.
            namespace: The default value injected for ``namespace`` arguments of wrapped functions.
            original: The module whose attributes are re-exported.
            doc: The module docstring. A generated description is used when empty or omitted.
            __all__: The names to re-export. Defaults to ``original.__all__`` without
                ``'reexport'``.
            __dir__: The names reported by :func:`dir`. Defaults to the public names of
                ``original`` without ``'reexport'``; always restricted to the re-exported names.
            extra_members: Additional attributes set directly on the new module. Their names are
                reported by :func:`dir` but are not added to ``__all__``.
        """
        if not doc:
            doc = (
                f'Re-exports :mod:`{original.__name__}` as :mod:`{name}` '
                f'with namespace :const:`{namespace!r}`.'
            )
        super().__init__(name, doc)

        exported = (
            {member for member in original.__all__ if member != 'reexport'}
            if __all__ is None
            else set(__all__)
        )
        listed = (
            {
                member
                for member in original.__dir__()
                if member != 'reexport' and not member.startswith('_')
            }
            if __dir__ is None
            else set(__dir__)
        ) & exported

        # Extra members are attached eagerly and advertised by ``dir()``.
        for key, value in (extra_members or {}).items():
            setattr(self, key, value)
            listed.add(key)

        # Private (name-mangled) state consulted by the methods below.
        self.__namespace = namespace
        self.__original = original
        self.__all_set = exported
        self.__all = sorted(exported)
        self.__dir = sorted(listed)

    @property
    def __all__(self) -> list[str]:
        """The sorted list of names re-exported by this module."""
        return self.__all

    def __dir__(self) -> list[str]:
        """Return a fresh list of the names advertised by :func:`dir`."""
        return list(self.__dir)

    def __getattr__(self, name: str, /) -> Any:
        """Resolve ``name`` from the original module, wrap functions, and cache the result.

        Raises:
            AttributeError: If ``name`` is not one of the re-exported names.
        """
        if name not in self.__all_set:
            raise AttributeError(f'module {self.__name__!r} has no attribute {name!r}')

        value = getattr(self.__original, name)
        if _inspect.isfunction(value):
            value = self.__reexport__(value)
        # Cache on the instance so later lookups bypass ``__getattr__``.
        setattr(self, name, value)
        return value

    def __reexport__(self, func: Callable[_P, _T], /) -> Callable[_P, _T]:
        """Wrap ``func`` so that its ``namespace`` argument defaults to this module's namespace.

        Functions without a ``namespace`` parameter are wrapped by a plain forwarding wrapper. When
        ``func`` carries a callable ``get`` attribute, that attribute is re-exported recursively and
        attached to the returned wrapper (the original ``func`` is left untouched).
        """
        signature = _inspect.signature(func)
        accepts_namespace = 'namespace' in signature.parameters

        if not accepts_namespace:

            @_functools.wraps(func)
            def wrapped(*args: _P.args, **kwargs: _P.kwargs) -> _T:
                return func(*args, **kwargs)

        else:
            default_namespace = self.__namespace

            @_functools.wraps(func)
            def wrapped(  # type: ignore[valid-type]
                *args: _P.args,
                namespace: str = default_namespace,
                **kwargs: _P.kwargs,
            ) -> _T:
                return func(*args, namespace=namespace, **kwargs)  # type: ignore[arg-type]

            # Rewrite the documented default so help() reflects the bound namespace.
            if func.__doc__:  # pragma: no branch
                wrapped.__doc__ = func.__doc__.replace(
                    "(default: :const:`''`, i.e., the global namespace)",
                    f'(default: :const:`{default_namespace!r}`)',
                )
            # Advertise the new default through ``inspect.signature`` as well.
            wrapped.__signature__ = signature.replace(  # type: ignore[attr-defined]
                parameters=[
                    (
                        param.replace(default=default_namespace)
                        if param.name == 'namespace'
                        else param
                    )
                    for param in signature.parameters.values()
                ],
            )

        getter = getattr(func, 'get', None)
        if callable(getter):
            wrapped.get = self.__reexport__(getter)  # type: ignore[attr-defined]

        return wrapped


if _TYPE_CHECKING:
    # pylint: disable-next=missing-class-docstring,too-few-public-methods
    class ReexportedPyTreeModule(ReexportedModule):
        """Static type of the module returned by :func:`reexport` (type checking only).

        Each attribute below mirrors the same-named member of this module. Functions are wrapped in
        ``staticmethod`` so that accessing them on an instance of this type does not bind ``self``.
        """

        # Members that the runtime ``reexport()`` passes via ``extra_members``.
        __version__: str
        functools: _ModuleType
        dataclasses: _ModuleType

        # Names from this module's ``__all__`` (except ``'reexport'``).
        PyTreeSpec: type[PyTreeSpec] = PyTreeSpec
        PyTreeKind: type[PyTreeKind] = PyTreeKind
        PyTreeEntry: type[PyTreeEntry] = PyTreeEntry
        flatten = staticmethod(flatten)
        flatten_with_path = staticmethod(flatten_with_path)
        flatten_with_accessor = staticmethod(flatten_with_accessor)
        unflatten = staticmethod(unflatten)
        iter = staticmethod(iter)
        leaves = staticmethod(leaves)
        structure = staticmethod(structure)
        paths = staticmethod(paths)
        accessors = staticmethod(accessors)
        is_leaf = staticmethod(is_leaf)
        map = staticmethod(map)
        map_ = staticmethod(map_)
        map_with_path = staticmethod(map_with_path)
        map_with_path_ = staticmethod(map_with_path_)
        map_with_accessor = staticmethod(map_with_accessor)
        map_with_accessor_ = staticmethod(map_with_accessor_)
        replace_nones = staticmethod(replace_nones)
        partition = staticmethod(partition)
        transpose = staticmethod(transpose)
        transpose_map = staticmethod(transpose_map)
        transpose_map_with_path = staticmethod(transpose_map_with_path)
        transpose_map_with_accessor = staticmethod(transpose_map_with_accessor)
        broadcast_prefix = staticmethod(broadcast_prefix)
        broadcast_common = staticmethod(broadcast_common)
        broadcast_map = staticmethod(broadcast_map)
        broadcast_map_with_path = staticmethod(broadcast_map_with_path)
        broadcast_map_with_accessor = staticmethod(broadcast_map_with_accessor)
        reduce = staticmethod(reduce)
        sum = staticmethod(sum)
        max = staticmethod(max)
        min = staticmethod(min)
        all = staticmethod(all)
        any = staticmethod(any)
        flatten_one_level = staticmethod(flatten_one_level)
        register_node = staticmethod(register_node)
        register_node_class = staticmethod(register_node_class)
        unregister_node = staticmethod(unregister_node)
        dict_insertion_ordered = staticmethod(dict_insertion_ordered)

    def reexport(*, namespace: str, module: str | None = None) -> ReexportedPyTreeModule:
        """Re-export a pytree utility module with the given namespace as default.

        Static-analysis stub, defined only when ``TYPE_CHECKING`` is true: its return annotation
        exposes the members of the re-exported module to type checkers. Calling it always raises;
        the runtime implementation lives in the ``else`` branch of the enclosing ``if``.
        """
        message = 'reexport() is not available in type checking mode'
        raise NotImplementedError(message)

else:

    def reexport(*, namespace: str, module: str | None = None) -> _ModuleType:  # type: ignore[misc]
        """Re-export a pytree utility module with the given namespace as default.

        >>> import optree
        >>> pytree = optree.pytree.reexport(namespace='my-pkg', module='my_pkg.pytree')
        >>> pytree.flatten({'a': 1, 'b': 2})
        ([1, 2], PyTreeSpec({'a': *, 'b': *}))

        Downstream libraries can use this function to expose the pytree utilities under their own
        registry namespace::

            # foo/__init__.py
            import optree
            pytree = optree.pytree.reexport(namespace='foo')

            # foo/bar.py
            from foo import pytree

            @pytree.dataclasses.dataclass
            class Bar:
                a: int
                b: float

            print(pytree.flatten({'a': 1, 'b': 2, 'c': Bar(3, 4.0)}))
            # Output:
            #   ([1, 2, 3, 4.0], PyTreeSpec({'a': *, 'b': *, 'c': CustomTreeNode(Bar[()], [*, *])}, namespace='foo'))

        Three modules are created and registered in :data:`sys.modules`: ``<module>``,
        ``<module>.dataclasses``, and ``<module>.functools``. In each of them, re-exported functions
        that take a ``namespace`` argument use ``namespace`` as its default value.

        Args:
            namespace (str): The registry namespace used as the default ``namespace`` argument of the
                re-exported functions. Use ``''`` for the global namespace.
            module (str, optional): The fully qualified name of the module to create. If not
                provided, defaults to ``<caller_module>.pytree``, where the caller module is
                determined by inspecting the stack frame (falling back to ``'__main__'``).

        Returns:
            The re-exported module.

        Raises:
            TypeError: If ``namespace`` is not a string.
            ValueError: If ``module`` is not a valid dotted module name, or if ``<module>``,
                ``<module>.dataclasses``, or ``<module>.functools`` is already in :data:`sys.modules`.
        """
        # pylint: disable-next=import-outside-toplevel
        from optree.registry import __GLOBAL_NAMESPACE as GLOBAL_NAMESPACE

        # Normalize the registry's global-namespace sentinel to its string spelling.
        if namespace is GLOBAL_NAMESPACE:
            namespace = ''
        if not isinstance(namespace, str):
            raise TypeError(f'The namespace must be a string, got {namespace!r}.')

        if module is None:
            # Frame depth 1 is the caller of ``reexport()``; keep this inline so the depth holds.
            try:
                # pylint: disable-next=protected-access
                caller_module = _sys._getframemodulename(1) or '__main__'  # type: ignore[attr-defined]
            except AttributeError:  # pragma: no cover
                # ``sys._getframemodulename`` was added in Python 3.12.
                try:
                    # pylint: disable-next=protected-access
                    caller_module = _sys._getframe(1).f_globals.get('__name__', '__main__')
                except (AttributeError, ValueError):
                    caller_module = '__main__'
            module = f'{caller_module}.pytree'

        # ``_all`` is the builtin ``all``; the bare name is shadowed by ``tree_all`` in this module.
        is_valid_name = bool(module) and _all(part.isidentifier() for part in module.split('.'))
        if not is_valid_name:
            raise ValueError(f'invalid module name: {module!r}')

        dataclasses_module_name = f'{module}.dataclasses'
        functools_module_name = f'{module}.functools'
        for candidate in (module, dataclasses_module_name, functools_module_name):
            if candidate in _sys.modules:
                raise ValueError(f'module {candidate!r} already exists')

        reexported_dataclasses = ReexportedModule(
            dataclasses_module_name,
            original=dataclasses,
            namespace=namespace,
        )
        reexported_functools = ReexportedModule(
            functools_module_name,
            original=functools,
            namespace=namespace,
        )
        mod: ReexportedPyTreeModule = ReexportedModule(  # type: ignore[assignment]
            module,
            original=_sys.modules[__name__],
            namespace=namespace,
            extra_members={
                '__version__': __version__,
                'dataclasses': reexported_dataclasses,
                'functools': reexported_functools,
            },
        )

        # Register all three modules so that they can be imported by name.
        _sys.modules.update(
            {
                module: mod,
                dataclasses_module_name: reexported_dataclasses,
                functools_module_name: reexported_functools,
            },
        )
        return mod
