import copy
import inspect
import itertools
import string
import warnings

from keras.src import tree
from keras.src.backend.common.stateless_scope import StatelessScope
from keras.src.utils.module_utils import tensorflow as tf


class JaxExportArchive:
    def __init__(self):
        self._backend_variables = []
        self._backend_trainable_variables = []
        self._backend_non_trainable_variables = []

    def _track_layer(self, layer):
        # Variables in the lists below are actually part of the trackables
        # that get saved, because the lists are created in __init__.
        trainable_variables = layer.trainable_variables
        non_trainable_variables = layer.non_trainable_variables

        self._tf_trackable.trainable_variables += tree.map_structure(
            self._convert_to_tf_variable, trainable_variables
        )
        self._tf_trackable.non_trainable_variables += tree.map_structure(
            self._convert_to_tf_variable, non_trainable_variables
        )
        self._tf_trackable.variables = (
            self._tf_trackable.trainable_variables
            + self._tf_trackable.non_trainable_variables
        )

        self._backend_trainable_variables += trainable_variables
        self._backend_non_trainable_variables += non_trainable_variables
        self._backend_variables = (
            self._backend_trainable_variables
            + self._backend_non_trainable_variables
        )

    def add_endpoint(self, name, fn, input_signature=None, **kwargs):
        jax2tf_kwargs = kwargs.pop("jax2tf_kwargs", None)
        # Use `copy.copy()` to avoid modification issues.
        jax2tf_kwargs = copy.copy(jax2tf_kwargs) or {}
        is_static = bool(kwargs.pop("is_static", False))

        # Configure `jax2tf_kwargs`
        if "native_serialization" not in jax2tf_kwargs:
            jax2tf_kwargs["native_serialization"] = (
                self._check_device_compatible()
            )
        if "polymorphic_shapes" not in jax2tf_kwargs:
            jax2tf_kwargs["polymorphic_shapes"] = self._to_polymorphic_shape(
                input_signature
            )

        # Note: we truncate the number of parameters to what is specified by
        # `input_signature`.
        fn_signature = inspect.signature(fn)
        fn_parameters = list(fn_signature.parameters.values())

        if is_static:
            from jax.experimental import jax2tf

            jax_fn = jax2tf.convert(fn, **jax2tf_kwargs)
            jax_fn.__signature__ = inspect.Signature(
                parameters=fn_parameters[0 : len(input_signature)],
                return_annotation=fn_signature.return_annotation,
            )

            decorated_fn = tf.function(
                jax_fn,
                input_signature=input_signature,
                autograph=False,
            )
        else:
            # 1. Create a stateless wrapper for `fn`
            # 2. jax2tf the stateless wrapper
            # 3. Create a stateful function that binds the variables with
            #    the jax2tf converted stateless wrapper
            # 4. Make the signature of the stateful function the same as the
            #    original function
            # 5. Wrap in a `tf.function`
            def stateless_fn(variables, *args, **kwargs):
                state_mapping = zip(self._backend_variables, variables)
                with StatelessScope(state_mapping=state_mapping) as scope:
                    output = fn(*args, **kwargs)

                # Gather updated non-trainable variables
                non_trainable_variables = []
                for var in self._backend_non_trainable_variables:
                    new_value = scope.get_current_value(var)
                    non_trainable_variables.append(new_value)
                return output, non_trainable_variables

            jax2tf_stateless_fn = self._convert_jax2tf_function(
                stateless_fn, input_signature, jax2tf_kwargs=jax2tf_kwargs
            )

            def stateful_fn(*args, **kwargs):
                output, non_trainable_variables = jax2tf_stateless_fn(
                    # Change the trackable `ListWrapper` to a plain `list`
                    list(self._tf_trackable.variables),
                    *args,
                    **kwargs,
                )
                for var, new_value in zip(
                    self._tf_trackable.non_trainable_variables,
                    non_trainable_variables,
                ):
                    var.assign(tf.cast(new_value, var.dtype))
                return output

            stateful_fn.__signature__ = inspect.Signature(
                parameters=fn_parameters[0 : len(input_signature)],
                return_annotation=fn_signature.return_annotation,
            )

            decorated_fn = tf.function(
                stateful_fn,
                input_signature=input_signature,
                autograph=False,
            )
        return decorated_fn

    def _convert_jax2tf_function(self, fn, input_signature, jax2tf_kwargs=None):
        from jax.experimental import jax2tf

        variables_shapes = self._to_polymorphic_shape(
            self._backend_variables, allow_none=False
        )
        input_shapes = list(jax2tf_kwargs["polymorphic_shapes"])
        jax2tf_kwargs["polymorphic_shapes"] = [variables_shapes] + input_shapes
        return jax2tf.convert(fn, **jax2tf_kwargs)

    def _to_polymorphic_shape(self, struct, allow_none=True):
        if allow_none:
            # Generates unique names: a, b, ... z, aa, ab, ... az, ba, ... zz
            # for unknown non-batch dims. Defined here to be scope per endpoint.
            dim_names = itertools.chain(
                string.ascii_lowercase,
                itertools.starmap(
                    lambda a, b: a + b,
                    itertools.product(string.ascii_lowercase, repeat=2),
                ),
            )

        def convert_shape(x):
            poly_shape = []
            for index, dim in enumerate(list(x.shape)):
                if dim is not None:
                    poly_shape.append(str(dim))
                elif not allow_none:
                    raise ValueError(
                        f"Illegal None dimension in {x} with shape {x.shape}"
                    )
                elif index == 0:
                    poly_shape.append("batch")
                else:
                    poly_shape.append(next(dim_names))
            return "(" + ", ".join(poly_shape) + ")"

        return tree.map_structure(convert_shape, struct)

    def _check_device_compatible(self):
        from jax import default_backend as jax_device

        if (
            jax_device() == "gpu"
            and len(tf.config.list_physical_devices("GPU")) == 0
        ):
            warnings.warn(
                "JAX backend is using GPU for export, but installed "
                "TF package cannot access GPU, so reloading the model with "
                "the TF runtime in the same environment will not work. "
                "To use JAX-native serialization for high-performance export "
                "and serving, please install `tensorflow-gpu` and ensure "
                "CUDA version compatibility between your JAX and TF "
                "installations."
            )
            return False
        else:
            return True
