import copy
import importlib
import os
import sys
import types

from keras.src import backend as backend_module
from keras.src.api_export import keras_export
from keras.src.backend.common import global_state


def in_tf_graph():
    """Return whether execution should be treated as TF graph (non-eager) mode.

    Returns:
        `True` if the `in_tf_graph_scope` global attribute is truthy (it is set
        by `TFGraphScope`), or if TensorFlow has already been imported and is
        not executing eagerly. `False` otherwise.
    """
    # An explicit `TFGraphScope` wins over whatever TensorFlow reports.
    if global_state.get_global_attribute("in_tf_graph_scope", False):
        return True

    # Only ask TensorFlow when something has already imported it, so this
    # check never pulls TensorFlow in by itself.
    if "tensorflow" not in sys.modules:
        return False

    from keras.src.utils.module_utils import tensorflow as tf

    return not tf.executing_eagerly()


def convert_tf_tensor(outputs, dtype=None):
    """Convert `outputs` to the active backend's tensor type when needed.

    No conversion happens when the active Keras backend is TensorFlow or when
    `in_tf_graph()` reports graph mode; `outputs` is then returned unchanged.

    Args:
        outputs: Value to convert.
        dtype: Optional dtype forwarded to `convert_to_tensor`.

    Returns:
        Either `outputs` itself or the result of
        `backend_module.convert_to_tensor(outputs, dtype=dtype)`.
    """
    # `backend()` is evaluated first; `in_tf_graph()` only runs when the
    # active backend is not TensorFlow.
    if backend_module.backend() == "tensorflow" or in_tf_graph():
        return outputs
    return backend_module.convert_to_tensor(outputs, dtype=dtype)


class TFGraphScope:
    """Context manager that forces `in_tf_graph()` to report graph mode.

    On entry it sets the `in_tf_graph_scope` global attribute to `True`; on
    exit it restores the value that attribute had when the scope was entered.
    Separate instances can therefore be nested.
    """

    def __init__(self):
        # Initialized here for backward compatibility; refreshed in
        # `__enter__` so the restored value reflects entry time rather than
        # construction time.
        self._original_value = global_state.get_global_attribute(
            "in_tf_graph_scope", False
        )

    def __enter__(self):
        # Capture the current flag at entry: the global state may have
        # changed since this object was constructed.
        self._original_value = global_state.get_global_attribute(
            "in_tf_graph_scope", False
        )
        global_state.set_global_attribute("in_tf_graph_scope", True)
        return self

    def __exit__(self, *args, **kwargs):
        # Restore the previous flag; returning None lets exceptions propagate.
        global_state.set_global_attribute(
            "in_tf_graph_scope", self._original_value
        )


class DynamicBackend:
    """A class that can be used to switch from one backend to another.

    Attribute access is forwarded to the module of the currently selected
    backend (`keras.src.backend.<name>`), imported on demand.

    Example:

    ```python
    backend = DynamicBackend("tensorflow")
    y = backend.square(tf.constant(...))
    backend.set_backend("jax")
    y = backend.square(jax.numpy.array(...))
    ```

    Args:
        backend: Initial backend to use (string). Defaults to the currently
            configured Keras backend.

    Raises:
        ValueError: If `backend` is not a supported backend name.
    """

    # Backends accepted by `set_backend()`. All except "numpy" are imported
    # on demand from `keras.src.backend.<name>` in `__getattr__`.
    _SUPPORTED_BACKENDS = ("tensorflow", "jax", "torch", "numpy", "openvino")

    def __init__(self, backend=None):
        # Route through `set_backend` so construction validates the name the
        # same way later switches do.
        self.set_backend(backend or backend_module.backend())

    def set_backend(self, backend):
        """Switch the backend that attribute lookups are forwarded to.

        Raises:
            ValueError: If `backend` is not a supported backend name.
        """
        if backend not in self._SUPPORTED_BACKENDS:
            raise ValueError(
                "Available backends are ('tensorflow', 'jax', 'torch', "
                f"'numpy' and 'openvino'). Received: backend={backend}"
            )
        self._backend = backend

    def reset(self):
        """Switch back to the currently configured Keras backend."""
        self._backend = backend_module.backend()

    @property
    def name(self):
        """Name of the currently selected backend."""
        return self._backend

    def __getattr__(self, name):
        # `__getattr__` only runs when normal lookup fails. If `_backend`
        # itself is missing (e.g. an instance built via `__new__` while
        # copying or unpickling), reading `self._backend` below would
        # re-enter this method forever, so fail fast instead.
        if name == "_backend":
            raise AttributeError(name)
        backend = self._backend
        if backend == "numpy":
            if backend_module.backend() == "numpy":
                return getattr(backend_module, name)
            raise NotImplementedError(
                "Currently, we cannot dynamically import the numpy backend "
                "because it would disrupt the namespace of the import."
            )
        if backend in self._SUPPORTED_BACKENDS:
            module = importlib.import_module(f"keras.src.backend.{backend}")
            return getattr(module, name)
        # Previously this fell through and silently returned None.
        raise AttributeError(
            f"'{type(self).__name__}' object has no attribute '{name}' "
            f"(unsupported backend: {backend!r})"
        )


@keras_export("keras.config.set_backend")
def set_backend(backend):
    """Reload the backend (and the Keras package).

    Example:

    ```python
    keras.config.set_backend("jax")
    ```

    ⚠️ WARNING ⚠️: Using this function is dangerous and should be done
    carefully. Changing the backend will **NOT** convert
    the type of any already-instantiated objects.
    Thus, any layers / tensors / etc. already created will no
    longer be usable without errors. It is strongly recommended **not**
    to keep around **any** Keras-originated objects instances created
    before calling `set_backend()`.

    This includes any function or class instance that uses any Keras
    functionality. All such code needs to be re-executed after calling
    `set_backend()`.

    Args:
        backend: Name of the backend to switch to: one of `"tensorflow"`,
            `"jax"`, `"torch"`, `"numpy"` or `"openvino"`.

    Raises:
        ValueError: If `backend` is not a supported backend name. The check
            runs before `KERAS_BACKEND` is modified or any module is unloaded.
    """
    # Validate first: an invalid name would otherwise purge every Keras
    # module and leave KERAS_BACKEND set to an unusable value.
    if backend not in ("tensorflow", "jax", "torch", "numpy", "openvino"):
        raise ValueError(
            "Available backends are ('tensorflow', 'jax', 'torch', "
            f"'numpy' and 'openvino'). Received: backend={backend}"
        )
    os.environ["KERAS_BACKEND"] = backend
    # Clear module cache so the next import re-executes Keras modules.
    loaded_modules = [key for key in sys.modules if key.startswith("keras")]
    for key in loaded_modules:
        del sys.modules[key]
    # Reimport Keras with the new backend (set via KERAS_BACKEND).
    import keras  # noqa: F401

    # Finally: rebind Keras submodules referenced from this module's globals
    # (note: `globals()` is this module's namespace, not the caller's) to the
    # freshly imported versions. Iterate over a snapshot since entries are
    # rebound inside the loop.
    for key, value in copy.copy(globals()).items():
        if not isinstance(value, types.ModuleType):
            continue
        module_name = getattr(value, "__name__", "")
        if module_name.startswith("keras."):
            globals()[key] = importlib.import_module(module_name)
