import inspect

import numpy as np

from keras.src import backend
from keras.src import tree
from keras.src.api_export import keras_export
from keras.src.backend.common.variables import is_float_dtype
from keras.src.backend.common.variables import standardize_dtype
from keras.src.layers.layer import Layer
from keras.src.saving import serialization_lib
from keras.src.utils import jax_utils
from keras.src.utils import tracking
from keras.src.utils.module_utils import jax


@keras_export("keras.layers.JaxLayer")
class JaxLayer(Layer):
    """Keras Layer that wraps a JAX model.

    This layer enables the use of JAX components within Keras when using JAX as
    the backend for Keras.

    ## Model function

    This layer accepts JAX models in the form of a function, `call_fn`, which
    must take the following arguments with these exact names:

    - `params`: trainable parameters of the model.
    - `state` (*optional*): non-trainable state of the model. Can be omitted if
        the model has no non-trainable state.
    - `rng` (*optional*): a `jax.random.PRNGKey` instance. Can be omitted if the
        model does not need RNGs during either training or inference.
    - `inputs`: inputs to the model, a JAX array or a `PyTree` of arrays.
    - `training` (*optional*): an argument specifying if we're in training mode
        or inference mode; `True` is passed in training mode. Can be omitted if
        the model behaves the same in training mode and inference mode.

    The `inputs` argument is mandatory. Inputs to the model must be provided via
    a single argument. If the JAX model takes multiple inputs as separate
    arguments, they must be combined into a single structure, for instance in a
    `tuple` or a `dict`.
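
    For instance, as a sketch with hypothetical names, a model taking two
    inputs can receive them packed in a `tuple`:

    ```python
    def call_fn(params, inputs):
        # Unpack the single `inputs` structure into the separate inputs.
        input1, input2 = inputs
        ...

    layer = JaxLayer(call_fn, params=params)
    outputs = layer((input1, input2))
    ```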

    ## Model weights initialization

    The initialization of the `params` and `state` of the model can be handled
    by this layer, in which case the `init_fn` argument must be provided. This
    allows the model to be initialized dynamically with the right shape.
    Alternatively, and if the shape is known, the `params` argument and
    optionally the `state` argument can be used to create an already initialized
    model.
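
    As a sketch, assuming `call_fn` and `init_fn` follow the required
    signatures and `trained_params` is a compatible `PyTree`:

    ```python
    # Dynamic initialization: `init_fn` is called at build time.
    layer = JaxLayer(call_fn, init_fn=init_fn)

    # Already initialized: pass `params` (and optionally `state`) directly.
    layer = JaxLayer(call_fn, params=trained_params)
    ```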

    The `init_fn` function, if provided, must take the following arguments with
    these exact names:

    - `rng`: a `jax.random.PRNGKey` instance.
    - `inputs`: a JAX array or a `PyTree` of arrays with placeholder values to
        provide the shape of the inputs.
    - `training` (*optional*): an argument specifying if we're in training mode
        or inference mode. `True` is always passed to `init_fn`. Can be omitted
        regardless of whether `call_fn` has a `training` argument.

    ## Models with non-trainable state

    For JAX models that have non-trainable state:

    - `call_fn` must have a `state` argument.
    - `call_fn` must return a `tuple` containing the outputs of the model and
        the new non-trainable state of the model.
    - `init_fn` must return a `tuple` containing the initial trainable params of
        the model and the initial non-trainable state of the model.

    This code shows a possible combination of `call_fn` and `init_fn` signatures
    for a model with non-trainable state. In this example, the model has a
    `training` argument and an `rng` argument in `call_fn`.

    ```python
    def stateful_call(params, state, rng, inputs, training):
        outputs = ...
        new_state = ...
        return outputs, new_state

    def stateful_init(rng, inputs):
        initial_params = ...
        initial_state = ...
        return initial_params, initial_state
    ```

    ## Models without non-trainable state

    For JAX models with no non-trainable state:

    - `call_fn` must not have a `state` argument.
    - `call_fn` must return only the outputs of the model.
    - `init_fn` must return only the initial trainable params of the model.

    This code shows a possible combination of `call_fn` and `init_fn` signatures
    for a model without non-trainable state. In this example, the model does not
    have a `training` argument and does not have an `rng` argument in `call_fn`.

    ```python
    def stateless_call(params, inputs):
        outputs = ...
        return outputs

    def stateless_init(rng, inputs):
        initial_params = ...
        return initial_params
    ```

    ## Conforming to the required signature

    If a model has a different signature than the one required by `JaxLayer`,
    one can easily write a wrapper function to adapt the arguments. This example
    shows a model that has multiple inputs as separate arguments, expects
    multiple RNGs in a `dict`, and has a `deterministic` argument with the
    opposite meaning of `training`. To conform, the inputs are combined in a
    single structure using a `tuple`, the RNG is split and used to populate the
    expected `dict`, and the Boolean flag is negated:

    ```python
    def my_model_fn(params, rngs, input1, input2, deterministic):
        ...
        if not deterministic:
            dropout_rng = rngs["dropout"]
            keep = jax.random.bernoulli(dropout_rng, dropout_rate, x.shape)
            x = jax.numpy.where(keep, x / dropout_rate, 0)
            ...
        ...
        return outputs

    def my_model_wrapper_fn(params, rng, inputs, training):
        input1, input2 = inputs
        rng1, rng2 = jax.random.split(rng)
        rngs = {"dropout": rng1, "preprocessing": rng2}
        deterministic = not training
        return my_model_fn(params, rngs, input1, input2, deterministic)

    keras_layer = JaxLayer(my_model_wrapper_fn, params=initial_params)
    ```

    ## Usage with Haiku modules

    `JaxLayer` enables the use of [Haiku](https://dm-haiku.readthedocs.io)
    components in the form of
    [`haiku.Module`](https://dm-haiku.readthedocs.io/en/latest/api.html#module).
    This is achieved by transforming the module per the Haiku pattern and then
    passing `module.apply` as the `call_fn` parameter and, if needed,
    `module.init` as the `init_fn` parameter.

    If the model has non-trainable state, it should be transformed with
    [`haiku.transform_with_state`](
      https://dm-haiku.readthedocs.io/en/latest/api.html#haiku.transform_with_state).
    If the model has no non-trainable state, it should be transformed with
    [`haiku.transform`](
      https://dm-haiku.readthedocs.io/en/latest/api.html#haiku.transform).
    Additionally, and optionally, if the module does not use RNGs in `apply`, it
    can be transformed with
    [`haiku.without_apply_rng`](
      https://dm-haiku.readthedocs.io/en/latest/api.html#without-apply-rng).
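
    As a sketch, for a hypothetical `my_stateful_module_fn` whose module has
    non-trainable state (e.g. `hk.BatchNorm` statistics):

    ```python
    transformed_module = hk.transform_with_state(my_stateful_module_fn)

    keras_layer = JaxLayer(
        call_fn=transformed_module.apply,  # returns `(outputs, new_state)`
        init_fn=transformed_module.init,  # returns `(params, state)`
    )
    ```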

    The following example shows how to create a `JaxLayer` from a Haiku module
    that uses random number generators via `hk.next_rng_key()` and takes a
    `training` positional argument:

    ```python
    class MyHaikuModule(hk.Module):
        def __call__(self, x, training):
            x = hk.Conv2D(32, (3, 3))(x)
            x = jax.nn.relu(x)
            x = hk.AvgPool((1, 2, 2, 1), (1, 2, 2, 1), "VALID")(x)
            x = hk.Flatten()(x)
            x = hk.Linear(200)(x)
            if training:
                x = hk.dropout(rng=hk.next_rng_key(), rate=0.3, x=x)
            x = jax.nn.relu(x)
            x = hk.Linear(10)(x)
            x = jax.nn.softmax(x)
            return x

    def my_haiku_module_fn(inputs, training):
        module = MyHaikuModule()
        return module(inputs, training)

    transformed_module = hk.transform(my_haiku_module_fn)

    keras_layer = JaxLayer(
        call_fn=transformed_module.apply,
        init_fn=transformed_module.init,
    )
    ```

    Args:
        call_fn: The function to call the model. See description above for the
            list of arguments it takes and the outputs it returns.
        init_fn: The function to call to initialize the model. See description
            above for the list of arguments it takes and the outputs it returns.
            If `None`, then `params` and/or `state` must be provided.
        params: A `PyTree` containing all the model trainable parameters. This
            allows passing trained parameters or controlling the initialization.
            If both `params` and `state` are `None`, `init_fn` is called at
            build time to initialize the trainable parameters of the model.
        state: A `PyTree` containing all the model non-trainable state. This
            allows passing learned state or controlling the initialization. If
            both `params` and `state` are `None`, and `call_fn` takes a `state`
            argument, then `init_fn` is called at build time to initialize the
            non-trainable state of the model.
        seed: Seed for random number generator. Optional.
        dtype: The dtype of the layer's computations and weights. Can also be a
            `keras.DTypePolicy`. Optional. Defaults to the default policy.
    """

    def __init__(
        self,
        call_fn,
        init_fn=None,
        params=None,
        state=None,
        seed=None,
        **kwargs,
    ):
        if backend.backend() != "jax":
            raise ValueError(
                "JaxLayer is only supported with the JAX backend. Current "
                f"backend: {backend.backend()}"
            )

        if init_fn is None and params is None and state is None:
            raise ValueError(
                "`init_fn`, `params` and `state` cannot all be `None`."
            )

        super().__init__(**kwargs)
        self.call_fn = call_fn
        self.init_fn = init_fn
        self.seed_generator = backend.random.SeedGenerator(seed)
        self.tracked_params = self._create_variables(params, trainable=True)
        self.tracked_state = self._create_variables(state, trainable=False)
        if self.params is not None or self.state is not None:
            self._build_at_init()

        self.call_fn_arguments = self._validate_signature(
            call_fn,
            "call_fn",
            {"params", "state", "rng", "inputs", "training"},
            {"inputs"},
        )
        self.has_state = "state" in self.call_fn_arguments

        if init_fn:
            self.init_fn_arguments = self._validate_signature(
                init_fn, "init_fn", {"rng", "inputs", "training"}, {"inputs"}
            )

    def _validate_signature(self, fn, fn_name, allowed, required):
        fn_parameters = inspect.signature(fn).parameters
        for parameter_name in required:
            if parameter_name not in fn_parameters:
                raise ValueError(
                    f"Missing required argument in `{fn_name}`: "
                    f"`{parameter_name}`"
                )

        parameter_names = []
        for parameter in fn_parameters.values():
            if parameter.name not in allowed:
                raise ValueError(
                    f"Unsupported argument in `{fn_name}`: `{parameter.name}`, "
                    f"supported arguments are `{'`, `'.join(allowed)}`"
                )
            parameter_names.append(parameter.name)

        return parameter_names

    @tracking.no_automatic_dependency_tracking
    def _create_variables(self, values, trainable):
        """Create a structure of variables from a structure of JAX arrays.

        `values` is traversed via JAX's `tree_map`. When a leaf is a JAX array
        or a tensor-like object, a corresponding variable is created with it as
        the initial value. The resulting structure of variables is assigned to
        `self.params` or `self.state` depending on `trainable`. Then, a
        flattened version of the variables is returned for tracking.
        `self.params` or `self.state` are intentionally not tracked because
        structures like `TrackedList` interfere with `jax.tree_util`.
        Note that leaf objects that are not JAX arrays and not tensor-like are
        left intact as they are assumed to be configuration used by the model.

        Args:
            values: the structure of values to traverse.
            trainable: whether to create trainable variables.

        Returns:
            flat list of variables initialized with `values` for tracking.
        """

        def create_variable(value):
            if backend.is_tensor(value) or isinstance(
                value, (np.ndarray, np.generic)
            ):
                dtype = value.dtype
                if is_float_dtype(dtype):
                    dtype = None  # Use the layer dtype policy
                return self.add_weight(
                    value.shape,
                    initializer=value,
                    dtype=dtype,
                    trainable=trainable,
                )
            elif isinstance(value, (bool, int, float)):
                dtype = standardize_dtype(type(value))
                if is_float_dtype(dtype):
                    dtype = None  # Use the layer dtype policy
                return self.add_weight(
                    (),
                    initializer=backend.convert_to_tensor(value),
                    dtype=dtype,
                    trainable=trainable,
                )
            else:
                return value

        # Use JAX's tree_map as it understands registered classes.
        variables = jax.tree_util.tree_map(create_variable, values)

        if trainable:
            self.params = variables
        else:
            self.state = variables

        flat_variables, _ = jax.tree_util.tree_flatten(variables)
        return flat_variables

    def _get_init_rng(self):
        """
        Returns a JAX `PRNGKey` or structure of `PRNGKey`s to pass to `init_fn`.

        By default, this returns a single `PRNGKey` retrieved by calling
        `self.seed_generator.next()`. Override this to return a different
        structure.

        Returns:
            a JAX `PRNGKey` or structure of `PRNGKey`s that will be passed as
            the `rng` argument of `init_fn`.
        """
        return self.seed_generator.next()

    def _get_call_rng(self, training):
        """
        Returns a JAX `PRNGKey` or structure of `PRNGKey`s to pass to `call_fn`.

        By default, this returns a single `PRNGKey` retrieved by calling
        `self.seed_generator.next()` when `training` is `True`, and `None` when
        `training` is `False`. Override this to return a different structure or
        to pass RNGs in inference mode too.
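
        A sketch of an override, in a hypothetical subclass, that also
        provides an RNG in inference mode:

        ```python
        class MyJaxLayer(JaxLayer):
            def _get_call_rng(self, training):
                # Always provide an RNG, regardless of `training`.
                return self.seed_generator.next()
        ```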

        Returns:
            a JAX `PRNGKey` or structure of `PRNGKey`s that will be passed as
            the `rng` argument of `call_fn`.
        """
        if training:
            return self.seed_generator.next()
        else:
            return None

    def build(self, input_shape):
        if self.params is not None or self.state is not None:
            return

        if jax_utils.is_in_jax_tracing_scope():
            # This exception is not actually shown; it is caught, and a
            # detailed warning about calling 'build' is printed.
            raise ValueError("'JaxLayer' cannot be built in tracing scope")

        # Initialize `params` and `state` if needed by calling `init_fn`.
        def create_input(shape):
            shape = [d if d is not None else 1 for d in shape]
            return jax.numpy.ones(shape)

        init_inputs = tree.map_shape_structure(create_input, input_shape)
        init_args = []
        for argument_name in self.init_fn_arguments:
            if argument_name == "rng":
                init_args.append(self._get_init_rng())
            elif argument_name == "inputs":
                init_args.append(init_inputs)
            elif argument_name == "training":
                init_args.append(True)

        init_result = self.init_fn(*init_args)
        if self.has_state:
            init_params, init_state = init_result
        else:
            init_params, init_state = init_result, None

        self.tracked_params = self._create_variables(
            init_params, trainable=True
        )
        self.tracked_state = self._create_variables(init_state, trainable=False)

    def call(self, inputs, training=False):
        def unwrap_variable(variable):
            return None if variable is None else variable.value

        call_args = []
        for argument_name in self.call_fn_arguments:
            if argument_name == "params":
                call_args.append(
                    jax.tree_util.tree_map(unwrap_variable, self.params)
                )
            elif argument_name == "state":
                call_args.append(
                    jax.tree_util.tree_map(unwrap_variable, self.state)
                )
            elif argument_name == "rng":
                call_args.append(self._get_call_rng(training))
            elif argument_name == "inputs":
                call_args.append(inputs)
            elif argument_name == "training":
                call_args.append(training)

        def assign_state_to_variable(value, variable):
            # This exists only to make debugging this error case easier.
            if not hasattr(variable, "assign"):
                raise ValueError(
                    "Structure mismatch: the structure of the state returned "
                    "by `call` does not match the structure of the state at "
                    "initialization time."
                )
            variable.assign(value)

        if self.has_state:
            predictions, new_state = self.call_fn(*call_args)
            jax.tree_util.tree_map(
                assign_state_to_variable, new_state, self.state
            )
            return predictions
        else:
            return self.call_fn(*call_args)

    def get_config(self):
        config = {
            "call_fn": serialization_lib.serialize_keras_object(self.call_fn),
            "init_fn": serialization_lib.serialize_keras_object(self.init_fn),
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))

    @classmethod
    def from_config(cls, config):
        call_fn = serialization_lib.deserialize_keras_object(config["call_fn"])
        init_fn = serialization_lib.deserialize_keras_object(config["init_fn"])
        config["call_fn"] = call_fn
        config["init_fn"] = init_fn
        return super().from_config(config)


@keras_export("keras.layers.FlaxLayer")
class FlaxLayer(JaxLayer):
    """Keras Layer that wraps a [Flax](https://flax.readthedocs.io) module.

    This layer enables the use of Flax components in the form of
    [`flax.linen.Module`](
        https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html)
    instances within Keras when using JAX as the backend for Keras.

    The module method to use for the forward pass can be specified via the
    `method` argument and is `__call__` by default. This method must take the
    following arguments with these exact names:

    - `self` if the method is bound to the module, which is the case for the
        default of `__call__`; otherwise `module`, to receive the module
        instance.
    - `inputs`: the inputs to the model, a JAX array or a `PyTree` of arrays.
    - `training` (*optional*): an argument specifying if we're in training mode
        or inference mode; `True` is passed in training mode.

    `FlaxLayer` handles the non-trainable state of your model and required RNGs
    automatically. Note that the `mutable` parameter of
    [`flax.linen.Module.apply()`](
        https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.apply)
    is set to `DenyList(["params"])`, which assumes that all the variables
    outside of the "params" collection are non-trainable weights.
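
    As an illustration, a `variables` structure for a module using
    `flax.linen.BatchNorm` might look like this (collection names other than
    "params" depend on the module):

    ```python
    variables = {
        "params": ...,  # trainable, handled as `params`
        "batch_stats": ...,  # non-trainable, handled as `state`
    }
    ```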

    This example shows how to create a `FlaxLayer` from a Flax `Module` with
    the default `__call__` method and no training argument:

    ```python
    class MyFlaxModule(flax.linen.Module):
        @flax.linen.compact
        def __call__(self, inputs):
            x = inputs
            x = flax.linen.Conv(features=32, kernel_size=(3, 3))(x)
            x = flax.linen.relu(x)
            x = flax.linen.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
            x = x.reshape((x.shape[0], -1))  # flatten
            x = flax.linen.Dense(features=200)(x)
            x = flax.linen.relu(x)
            x = flax.linen.Dense(features=10)(x)
            x = flax.linen.softmax(x)
            return x

    flax_module = MyFlaxModule()
    keras_layer = FlaxLayer(flax_module)
    ```

    This example shows how to wrap the module method to conform to the required
    signature. This allows having multiple input arguments and a training
    argument that has a different name and opposite meaning. It additionally
    shows how to use a function that is not bound to the module.

    ```python
    class MyFlaxModule(flax.linen.Module):
        @flax.linen.compact
        def forward(self, input1, input2, deterministic):
            ...
            return outputs

    def my_flax_module_wrapper(module, inputs, training):
        input1, input2 = inputs
        return module.forward(input1, input2, not training)

    flax_module = MyFlaxModule()
    keras_layer = FlaxLayer(
        module=flax_module,
        method=my_flax_module_wrapper,
    )
    ```
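
    This sketch shows how to pass pretrained variables, as returned by
    `flax.linen.Module.init()`, instead of relying on automatic initialization
    (`sample_inputs` is a hypothetical placeholder input):

    ```python
    variables = flax_module.init(jax.random.PRNGKey(0), sample_inputs)
    keras_layer = FlaxLayer(flax_module, variables=variables)
    ```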

    Args:
        module: An instance of `flax.linen.Module` or subclass.
        method: The method to call the model. This is generally a method in the
            `Module`. If not provided, the `__call__` method is used. `method`
            can also be a function not defined in the `Module`, in which case it
            must take the `Module` as the first argument. It is used for both
            `Module.init` and `Module.apply`. Details are documented in the
            `method` argument of [`flax.linen.Module.apply()`](
              https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.apply).
        variables: A `dict` containing all the variables of the module in the
            same format as what is returned by [`flax.linen.Module.init()`](
              https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.init).
            It should contain a "params" key and, if applicable, other keys for
            collections of variables for non-trainable state. This allows
            passing trained parameters and learned non-trainable state or
            controlling the initialization. If `None` is passed, the module's
            `init` function is called at build time to initialize the variables
            of the model.
    """

    def __init__(
        self,
        module,
        method=None,
        variables=None,
        **kwargs,
    ):
        # Late import to only require Flax when this is used.
        from flax.core import scope as flax_scope

        if backend.backend() != "jax":
            raise ValueError(
                "FlaxLayer is only supported with the JAX backend. Current "
                f"backend: {backend.backend()}"
            )

        self.module = module
        self.method = method

        # Treat every collection except "params" as mutable, non-trainable
        # state when calling `Module.apply`.
        apply_mutable = flax_scope.DenyList(["params"])

        def apply_with_training(params, state, rng, inputs, training):
            return self.module.apply(
                self._params_and_state_to_variables(params, state),
                inputs,
                rngs=rng,
                method=self.method,
                mutable=apply_mutable,
                training=training,
            )

        def apply_without_training(params, state, rng, inputs):
            return self.module.apply(
                self._params_and_state_to_variables(params, state),
                inputs,
                rngs=rng,
                method=self.method,
                mutable=apply_mutable,
            )

        def init_with_training(rng, inputs, training):
            return self._variables_to_params_and_state(
                self.module.init(
                    rng,
                    inputs,
                    method=self.method,
                    training=training,
                )
            )

        def init_without_training(rng, inputs):
            return self._variables_to_params_and_state(
                self.module.init(
                    rng,
                    inputs,
                    method=self.method,
                )
            )

        if (
            "training"
            in inspect.signature(method or module.__call__).parameters
        ):
            call_fn, init_fn = apply_with_training, init_with_training
        else:
            call_fn, init_fn = apply_without_training, init_without_training

        params, state = self._variables_to_params_and_state(variables)

        super().__init__(
            call_fn=call_fn,
            init_fn=init_fn,
            params=params,
            state=state,
            **kwargs,
        )

    def _params_and_state_to_variables(self, params, state):
        if params:
            if state:
                return {**params, **state}
            else:
                return params
        elif state:
            return state
        return {}

    def _variables_to_params_and_state(self, variables):
        # neither params nor state
        if variables is None:
            return None, None
        # state only
        if "params" not in variables:
            return {}, variables
        # params only
        if len(variables) == 1:
            return variables, {}
        # both, we need to split
        params = {"params": variables["params"]}
        state = {k: v for k, v in variables.items() if k != "params"}
        return params, state

    def _get_init_rng(self):
        # Flax modules require a "params" RNG at initialization time; a
        # "dropout" RNG is also provided for modules that use dropout.
        return {
            "params": self.seed_generator.next(),
            "dropout": self.seed_generator.next(),
        }

    def _get_call_rng(self, training):
        # Provide a "dropout" RNG in training mode only; in inference mode,
        # the module is expected to be deterministic.
        if training:
            return {"dropout": self.seed_generator.next()}
        else:
            return {}

    def get_config(self):
        config_method = self.method
        if (
            hasattr(self.method, "__self__")
            and self.method.__self__ == self.module
        ):
            # A method bound to the module is serialized by name.
            config_method = self.method.__name__
        config = {
            "module": serialization_lib.serialize_keras_object(self.module),
            "method": serialization_lib.serialize_keras_object(config_method),
        }
        base_config = super().get_config()
        # `call_fn` and `init_fn` come from the module; do not save them.
        base_config.pop("call_fn")
        base_config.pop("init_fn")
        return dict(list(base_config.items()) + list(config.items()))

    @classmethod
    def from_config(cls, config):
        module = serialization_lib.deserialize_keras_object(config["module"])
        method = serialization_lib.deserialize_keras_object(config["method"])
        if isinstance(config["method"], str):
            # Deserialize bound method from the module.
            method = getattr(module, method)
        config["module"] = module
        config["method"] = method
        return cls(**config)
