"""Wrapper layer to apply every temporal slice of an input."""

from keras.src import backend
from keras.src import ops
from keras.src.api_export import keras_export
from keras.src.layers.core.wrapper import Wrapper
from keras.src.layers.layer import Layer


@keras_export("keras.layers.TimeDistributed")
class TimeDistributed(Wrapper):
    """This wrapper allows to apply a layer to every temporal slice of an input.

    Every input should be at least 3D, and the dimension of index one of the
    first input will be considered to be the temporal dimension.

    Consider a batch of 32 video samples, where each sample is a 128x128 RGB
    image with `channels_last` data format, across 10 timesteps.
    The batch input shape is `(32, 10, 128, 128, 3)`.

    You can then use `TimeDistributed` to apply the same `Conv2D` layer to each
    of the 10 timesteps, independently:

    >>> inputs = layers.Input(shape=(10, 128, 128, 3), batch_size=32)
    >>> conv_2d_layer = layers.Conv2D(64, (3, 3))
    >>> outputs = layers.TimeDistributed(conv_2d_layer)(inputs)
    >>> outputs.shape
    (32, 10, 126, 126, 64)

    Because `TimeDistributed` applies the same instance of `Conv2D` to each of
    the timestamps, the same set of weights are used at each timestamp.

    Args:
        layer: a `keras.layers.Layer` instance.

    Call arguments:
        inputs: Input tensor of shape (batch, time, ...) or nested tensors,
            and each of which has shape (batch, time, ...).
        training: Python boolean indicating whether the layer should behave in
            training mode or in inference mode. This argument is passed to the
            wrapped layer (only if the layer supports this argument).
        mask: Binary tensor of shape `(samples, timesteps)` indicating whether
            a given timestep should be masked. This argument is passed to the
            wrapped layer (only if the layer supports this argument).
    """

    def __init__(self, layer, **kwargs):
        if not isinstance(layer, Layer):
            raise ValueError(
                "Please initialize `TimeDistributed` layer with a "
                f"`keras.layers.Layer` instance. Received: {layer}"
            )
        super().__init__(layer, **kwargs)
        self.supports_masking = True

    def _get_child_input_shape(self, input_shape):
        if not isinstance(input_shape, (tuple, list)) or len(input_shape) < 3:
            raise ValueError(
                "`TimeDistributed` Layer should be passed an `input_shape` "
                f"with at least 3 dimensions, received: {input_shape}"
            )
        return (input_shape[0], *input_shape[2:])

    def compute_output_shape(self, input_shape):
        child_input_shape = self._get_child_input_shape(input_shape)
        child_output_shape = self.layer.compute_output_shape(child_input_shape)
        return (child_output_shape[0], input_shape[1], *child_output_shape[1:])

    def build(self, input_shape):
        child_input_shape = self._get_child_input_shape(input_shape)
        super().build(child_input_shape)

    def call(self, inputs, training=None, mask=None):
        input_shape = ops.shape(inputs)
        mask_shape = None if mask is None else ops.shape(mask)
        batch_size = input_shape[0]
        timesteps = input_shape[1]

        # For TF backend with graph mode and `partial_batch_size`, skip
        # evaluation of `batch_size` as it can be a `strided_slice` and
        # not a constant.
        if backend.backend() == "tensorflow":
            from keras.src.utils.module_utils import tensorflow as tf

            if (
                not tf.executing_eagerly
                and mask_shape is not None
                and mask_shape[1:2] != (timesteps,)
            ):
                raise ValueError(
                    "`TimeDistributed` Layer should be passed a `mask` of "
                    f"shape ({batch_size}, {timesteps}, ...), "
                    f"received: mask.shape={mask_shape}"
                )
        elif mask_shape is not None and mask_shape[:2] != (
            batch_size,
            timesteps,
        ):
            raise ValueError(
                "`TimeDistributed` Layer should be passed a `mask` of "
                f"shape ({batch_size}, {timesteps}, ...), "
                f"received: mask.shape={mask_shape}"
            )

        def time_distributed_transpose(data):
            """Swaps the timestep and batch dimensions of a tensor."""
            axes = [1, 0, *range(2, len(data.shape))]
            return ops.transpose(data, axes=axes)

        inputs = time_distributed_transpose(inputs)
        if mask is not None:
            mask = time_distributed_transpose(mask)

        def step_function(i):
            kwargs = {}
            if self.layer._call_has_mask_arg and mask is not None:
                kwargs["mask"] = mask[i]
            if self.layer._call_has_training_arg:
                kwargs["training"] = training
            return self.layer.call(inputs[i], **kwargs)

        # Implementation #1: is the time axis is static, use a Python for loop.

        if inputs.shape[0] is not None:
            outputs = ops.stack(
                [step_function(i) for i in range(inputs.shape[0])]
            )
            return time_distributed_transpose(outputs)

        # Implementation #2: use backend.vectorized_map.

        outputs = backend.vectorized_map(step_function, ops.arange(timesteps))
        return time_distributed_transpose(outputs)
